diff --git a/Dockerfile b/Dockerfile
index 78c8167..ff1198f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,10 @@
-FROM python:3.12-slim
+FROM python:3.12
 
 WORKDIR /app
 
 COPY src/ ./
 
+RUN pip install --upgrade pip setuptools wheel
 RUN pip install --no-cache-dir -r requirements.txt
 
 CMD ["python", "main.py"]
diff --git a/pyproject.toml b/pyproject.toml
index 42d97b5..6c4fe69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "sniffio~=1.3.1",
     "tqdm~=4.67.1",
     "typing-extensions~=4.12.2",
+    "isodate~=0.7.2",
 ]
 readme = "README.md"
 requires-python = ">= 3.10"
diff --git a/src/db.py b/src/db.py
new file mode 100644
index 0000000..713e64d
--- /dev/null
+++ b/src/db.py
@@ -0,0 +1,32 @@
+import os
+import time
+
+import psycopg2
+import logging
+
+logger = logging.getLogger(__name__)
+
+DB_USER = os.environ["DB_USER"]
+DB_PASSWORD = os.environ["DB_PASSWORD"]
+DB_NAME = os.environ["DB_NAME"]
+DB_HOST = os.environ["DB_HOST"]
+
+# This is needed because the DB container takes a longer time to start,
+# so the DB may not be available in the beginning.
+for _ in range(5):
+    try:
+        con = psycopg2.connect(
+            dbname=DB_NAME,
+            user=DB_USER,
+            password=DB_PASSWORD,
+            host=DB_HOST,
+            port=5432,
+        )
+        logger.info("Connection successful!")
+        break  # success! no need to repeat
+    except psycopg2.OperationalError as e:
+        logger.error("Error while connecting to the database: %s", e)
+        time.sleep(5)
+else:
+    logger.error("Can't connect to the database. Abort.")
+    exit(1)
diff --git a/src/evaluate.py b/src/evaluate.py
deleted file mode 100644
index 2ac69e2..0000000
--- a/src/evaluate.py
+++ /dev/null
@@ -1,13 +0,0 @@
-def evaluate(expression: str) -> str:
-    try:
-        ans = eval(expression)
-    except Exception as error:
-        return str(type(error).__name__)
-    return str(ans)
-
-
-def test_evaluate() -> None:
-    assert evaluate("123 + 456") == str(123 + 456)
-    assert evaluate("455 +_/ 342") == "NameError"
-    assert evaluate("455 +_( 342") == "SyntaxError"
-    assert evaluate("455 / 0") == "ZeroDivisionError"
diff --git a/src/handler.py b/src/handler.py
index bb51bec..6bca922 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -1,4 +1,5 @@
 import base64
+import json
 import os
 import uuid
 from typing import Any
@@ -7,15 +8,17 @@
 from telegram import Update, Message, PhotoSize
 from telegram.ext import ExtBot, CallbackContext
 from llm import ask_ai
-from logger import logger
+import logging
+
+logger = logging.getLogger(__name__)
 
 BOT_NAME = "ButlerBot"
 BOT_USER_ID = 0
 BOT_MESSAGE_ID = 0
 no_reply_token = "-"
 SYSTEM_PROMPT = f"""
-    Each message in the conversation below is prefixed with the username and their unique
-    identifier, like this: "username (123456789): MESSAGE...". '
+    Each message in the conversation below is sent in json like this
+    {{"user_name": user_name, "user_id": user_id, "chat_id": chat_id, "message": message}}.
     You play the role of the user called {BOT_NAME}, or simply Bot; your username and unique
     identifier are {BOT_NAME} and 0.
     You are observing the users' conversation and normally you do not interfere
@@ -23,6 +26,8 @@
     Explicit mentions include cases where your name or identifier appears anywhere in the message.
     If you are not explicitly addressed, always respond with {no_reply_token}.
     When answering, don't use LaTeX.
+    When setting/editing reminder, you are not allowed to answer from your own knowledge.
+    You must call the appropriate tool and return its output.
 """
 DB_BLOB_DIR = Path(os.environ["DB_BLOB_DIR"])
 DB_BLOB_DIR.mkdir(parents=True, exist_ok=True)
@@ -128,19 +133,27 @@ async def generate_response(chat_id: int, con: psycopg2.connect) -> str:
         messages.append(
             {
                 "role": "assistant" if user_id == 0 else "user",
-                "content": f"{user_name} ({user_id}): {message}",
+                "content": json.dumps(
+                    {"user_name": user_name, "user_id": user_id, "chat_id": chat_id, "message": message}
+                ),
             }
         )
     logger.info("all messages: %s", messages)
     response = await ask_ai(messages)
-
-    response = response.removeprefix(f"{BOT_NAME} ({BOT_USER_ID}): ")
+    # Replies normally mirror the JSON envelope used for the conversation, but the
+    # model may answer in plain text (e.g. the bare no-reply token); keep the raw reply then.
+    try:
+        payload = json.loads(response)
+    except json.JSONDecodeError:
+        payload = None
+    if isinstance(payload, dict):
+        response = payload.get("message", "no message received.")
     cur.execute(
         """
         INSERT INTO user_message (chat_id, user_id, message)
-        VALUES (%s, 0, %s)
+        VALUES (%s, %s, %s)
         """,
-        (chat_id, response),
+        (chat_id, BOT_USER_ID, response),
     )
     con.commit()
     return response
diff --git a/src/llm.py b/src/llm.py
index 00c640c..a5d85eb 100644
--- a/src/llm.py
+++ b/src/llm.py
@@ -1,22 +1,27 @@
 import asyncio
 import json
+from typing import Any
 import os
 
 import openai
 from openai import OpenAI
 from openai.types.chat import ChatCompletion
-from evaluate import evaluate
-from logger import logger
+from tool_function import evaluate, set_reminder
+import logging
+
+logger = logging.getLogger(__name__)
 
 # todo: make this a class
 
 OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
 OPENAI_MODEL = "gpt-4o"
 
+TOOL_DEF = json.load(open("tools.json"))
+
 client = OpenAI(api_key=OPENAI_API_KEY)
 
 
-async def ask_ai(messages: list) -> str:
+async def ask_ai(messages: list[Any]) -> str:
     loop = asyncio.get_running_loop()  # gain access to the scheduler
 
     def runs_in_background_thread() -> ChatCompletion:
@@ -25,7 +30,7 @@ def runs_in_background_thread() -> ChatCompletion:
             completion = client.chat.completions.create(
                 model=OPENAI_MODEL,
                 messages=messages,
-                tools=json.load(open("tools.json")),
+                tools=TOOL_DEF,
             )
         except openai.BadRequestError as e:
             logger.error(f"OpenAI API error: {e}")
@@ -40,10 +45,23 @@ def runs_in_background_thread() -> ChatCompletion:
 
     while message.tool_calls:
         tool_call = message.tool_calls[0]
+        logger.info(f"tool call {tool_call}")
+        function = tool_call.function.name
+        answer = "no function is called."
+        if function == "set_reminder":
+            arguments = json.loads(tool_call.function.arguments)
+            chat_id, action, duration, deadline = (
+                arguments["chat_id"],
+                arguments["action"],
+                arguments.get("duration", None),
+                arguments.get("deadline", None),
+            )
+            answer = set_reminder(chat_id, action, deadline, duration)
+        elif function == "evaluate":
+            arguments = json.loads(tool_call.function.arguments)
+            expression = arguments["expression"]
+            answer = evaluate(expression)
         logger.info(f"Tool call message: {message}")
-        arguments = json.loads(tool_call.function.arguments)
-        expression = arguments["expression"]
-        answer = evaluate(expression)
         function_call_result_message = {
             "role": "tool",
             "content": json.dumps({"result": answer}),
diff --git a/src/logger.py b/src/logger.py
deleted file mode 100644
index dbcc0ab..0000000
--- a/src/logger.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import logging
-
-# Enable logging
-logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
-# set higher logging level for httpx to avoid all GET and POST requests being logged
-logging.getLogger("httpx").setLevel(logging.WARNING)
-
-logger = logging.getLogger(__name__)
diff --git a/src/main.py b/src/main.py
index 04bfc4a..ae2063c 100755
--- a/src/main.py
+++ b/src/main.py
@@ -2,11 +2,10 @@
 # Copyright Song Meo
 
 import asyncio
-import time
 from datetime import datetime, timedelta, timezone
 from typing import Any
-import psycopg2
 import os
+from db import con
 from dotenv import load_dotenv
 import telegram
 from telegram import Update, error
@@ -21,18 +20,66 @@
 )
 from handler import store_message, generate_response, help_command
 from handler import BOT_NAME, BOT_USER_ID
-from logger import logger
+import logging
+
+logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
+# set higher logging level for httpx to avoid all GET and POST requests being logged
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+logger = logging.getLogger(__name__)
 
 load_dotenv()
 
-DB_USER = os.environ["DB_USER"]
-DB_PASSWORD = os.environ["DB_PASSWORD"]
-DB_NAME = os.environ["DB_NAME"]
-DB_HOST = os.environ["DB_HOST"]
 TOKEN = os.environ["TOKEN"]
 
 
-async def generate_response_loop(con: psycopg2.connect) -> None:
+async def send_reminder_loop() -> None:
+    """Deliver due reminders to their chats, polling the database once a minute."""
+    # TODO add logging for total loop failure.
+    while True:
+        cur = con.cursor()
+        cur.execute(
+            "SELECT id, chat_id, action, deadline FROM user_reminder WHERE deadline <= %s AND is_notified = %s",
+            (
+                # deadline is TIMESTAMPTZ, so compare it against an aware UTC timestamp.
+                datetime.now(timezone.utc),
+                False,
+            ),
+        )
+        reminders = cur.fetchall()
+        for r in reminders:
+            reminder_id, chat_id, action, deadline = r
+            bot = telegram.Bot(token=TOKEN)
+
+            # This is a simplified solution; in the future, we should ask the LLM to process reminder events.
+            # In response, the LLM can invoke another tool, like message a specific user (doesn't have to be
+            # the user that created the reminder), or do this and that.
+            message = f"This is a reminder to {action} at {deadline}."
+
+            try:
+                cur.execute(
+                    """
+                    INSERT INTO user_message (chat_id, user_id, message)
+                    VALUES (%s, %s, %s)
+                    """,
+                    (chat_id, BOT_USER_ID, message),
+                )
+
+                await bot.send_message(chat_id=chat_id, text=message)
+                # Flag only this reminder; matching on the action text would also
+                # silence unrelated reminders that happen to share it.
+                cur.execute("UPDATE user_reminder SET is_notified = TRUE WHERE id = %s", (reminder_id,))
+                con.commit()
+            except Exception:
+                # One undeliverable reminder must not end the loop or leave the shared
+                # connection stuck in a failed transaction; it is retried on the next pass.
+                con.rollback()
+                logger.exception("Failed to deliver reminder %s", reminder_id)
+
+        await asyncio.sleep(60)
+
+
+async def generate_response_loop() -> None:
     while True:
         cur = con.cursor()
         cur.execute("SELECT chat_id FROM user_message")
@@ -66,25 +113,7 @@ async def generate_response_loop(con: psycopg2.connect) -> None:
 
 def main() -> None:
     application = Application.builder().token(TOKEN).build()
-
-    for _ in range(5):
-        try:
-            con = psycopg2.connect(
-                dbname=DB_NAME,
-                user=DB_USER,
-                password=DB_PASSWORD,
-                host=DB_HOST,
-                port=5432,
-            )
-            cur = con.cursor()
-            logger.info("Connection successful!")
-            break  # success! no need to repeat
-        except psycopg2.OperationalError as e:
-            logger.error("Error while connecting to the database:", e)
-            time.sleep(5)
-    else:
-        logger.error("Can't connect to the database. Abort.")
-        exit(1)
+    cur = con.cursor()
 
     cur.execute(
         """
@@ -126,6 +155,20 @@ def main() -> None:
         )
         """
     )
+
+    cur.execute(
+        """
+        CREATE TABLE IF NOT EXISTS user_reminder (
+            id SERIAL PRIMARY KEY, -- SERIAL handles auto-incrementing
+            chat_id BIGINT NOT NULL,
+            action TEXT NOT NULL,
+            created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
+            deadline TIMESTAMPTZ NOT NULL,
+            is_notified BOOLEAN DEFAULT FALSE
+        )
+        """
+    )
+
     con.commit()
 
     async def sticker_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -161,7 +204,8 @@ async def message_handler_proxy(update: Update, context: ContextTypes.DEFAULT_TY
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
 
-    loop.create_task(generate_response_loop(con))
+    loop.create_task(generate_response_loop())
+    loop.create_task(send_reminder_loop())
 
     # Run the bot until the user presses Ctrl-C
     application.run_polling(allowed_updates=Update.ALL_TYPES)
diff --git a/src/requirements.txt b/src/requirements.txt
index fa116d9..b181981 100644
--- a/src/requirements.txt
+++ b/src/requirements.txt
@@ -7,6 +7,7 @@ h11~=0.14.0
 httpcore~=1.0.7
 httpx~=0.28.1
 idna~=3.10
+isodate~=0.7.2
 jiter~=0.8.0
 openai~=1.57.1
 psycopg2-binary~=2.9.7
diff --git a/src/tool_function.py b/src/tool_function.py
new file mode 100644
index 0000000..22088fc
--- /dev/null
+++ b/src/tool_function.py
@@ -0,0 +1,51 @@
+from db import con
+from datetime import datetime, timezone
+import isodate
+
+
+def set_reminder(chat_id: int, action: str, deadline: str, duration: str) -> str:
+    """Store a reminder for a chat and return a confirmation or rejection message.
+
+    ``deadline`` is an absolute ISO 8601 timestamp and ``duration`` an ISO 8601
+    duration counted from now; ``deadline`` takes precedence when both are set.
+    """
+    today = datetime.now(timezone.utc)
+    if deadline:
+        due = datetime.fromisoformat(deadline)
+        if due.tzinfo is None:
+            # A naive timestamp cannot be subtracted from the aware "now"; assume UTC.
+            due = due.replace(tzinfo=timezone.utc)
+            deadline = due.isoformat()
+        # timedelta.seconds is never negative, so total_seconds() is needed to detect the past.
+        if (due - today).total_seconds() < -60:
+            return "Sorry deadline is past."
+    elif duration:
+        td = isodate.parse_duration(duration)
+        deadline = today + td
+    else:
+        return "You must define deadline or duration."
+    cur = con.cursor()
+
+    cur.execute(
+        "INSERT INTO user_reminder (chat_id, action, deadline) VALUES (%s, %s, %s)", (chat_id, action, deadline)
+    )
+    con.commit()
+
+    return f"A reminder for '{action}' is set on {deadline}."
+
+
+def evaluate(expression: str) -> str:
+    # NOTE(security): eval() executes whatever text the model supplies; replace it
+    # with a restricted arithmetic parser before exposing the bot to untrusted chats.
+    try:
+        ans = eval(expression)
+    except Exception as error:
+        return str(type(error).__name__)
+    return str(ans)
+
+
+def test_evaluate() -> None:
+    assert evaluate("123 + 456") == str(123 + 456)
+    assert evaluate("455 +_/ 342") == "NameError"
+    assert evaluate("455 +_( 342") == "SyntaxError"
+    assert evaluate("455 / 0") == "ZeroDivisionError"
diff --git a/src/tools.json b/src/tools.json
index 9f1a03a..08bb80a 100644
--- a/src/tools.json
+++ b/src/tools.json
@@ -16,5 +16,40 @@
       "required": ["expression"],
       "additionalProperties": false
     }
+  },
+  {
+    "type": "function",
+    "function": {
+      "name": "set_reminder",
+      "description": "Set a reminder on a specific date by calling this function.",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "duration": {
+            "type": "string",
+            "description": "A relative time like 'PT10M' (ISO 8601 duration = '10 minutes from now')"
+          },
+          "deadline": {
+            "type": "string",
+            "format": "date-time",
+            "description": "An absolute ISO 8601 time like '2025-04-24T15:00:00Z'"
+          },
+          "chat_id": {
+            "type": "integer",
+            "description": "the chat id of conversation"
+          },
+          "action": {
+            "type": "string",
+            "description": "the description of the reminder"
+          }
+        }
+      },
+      "required": ["action", "chat_id"],
+      "oneOf": [
+        { "required": ["duration"] },
+        { "required": ["deadline"] }
+      ],
+      "additionalProperties": false
+    }
   }
 ]