Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
FROM python:3.12-slim
FROM python:3.12

WORKDIR /app

COPY src/ ./

RUN pip install --upgrade pip setuptools wheel
RUN pip install --no-cache-dir -r requirements.txt

CMD ["python", "main.py"]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"sniffio~=1.3.1",
"tqdm~=4.67.1",
"typing-extensions~=4.12.2",
"isodate~=0.7.2",
]
readme = "README.md"
requires-python = ">= 3.10"
Expand Down
32 changes: 32 additions & 0 deletions src/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import time

import psycopg2
import logging

logger = logging.getLogger(__name__)

DB_USER = os.environ["DB_USER"]
DB_PASSWORD = os.environ["DB_PASSWORD"]
DB_NAME = os.environ["DB_NAME"]
DB_HOST = os.environ["DB_HOST"]

# This is needed because the DB container takes a longer time to start,
# so the DB may not be available in the beginning.
for _ in range(5):
try:
con = psycopg2.connect(
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
host=DB_HOST,
port=5432,
)
logger.info("Connection successful!")
break # success! no need to repeat
except psycopg2.OperationalError as e:
logger.error("Error while connecting to the database:", e)
time.sleep(5)
else:
logger.error("Can't connect to the database. Abort.")
exit(1)
13 changes: 0 additions & 13 deletions src/evaluate.py

This file was deleted.

22 changes: 14 additions & 8 deletions src/handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import json
import os
import uuid
from typing import Any
Expand All @@ -7,22 +8,26 @@
from telegram import Update, Message, PhotoSize
from telegram.ext import ExtBot, CallbackContext
from llm import ask_ai
from logger import logger
import logging

logger = logging.getLogger(__name__)

BOT_NAME = "ButlerBot"
BOT_USER_ID = 0
BOT_MESSAGE_ID = 0
no_reply_token = "-"
SYSTEM_PROMPT = f"""
Each message in the conversation below is prefixed with the username and their unique
identifier, like this: "username (123456789): MESSAGE...". '
Each message in the conversation below is sent in json like this
{{"user_name": user_name, "user_id": user_id, "chat_id": chat_id, "message": message}}.
You play the role of the user called {BOT_NAME}, or simply Bot;
your username and unique identifier are {BOT_NAME} and 0.
You are observing the users' conversation and normally you do not interfere
unless you are explicitly called by name (e.g., 'bot,' '{BOT_NAME},' etc.).
Explicit mentions include cases where your name or identifier appears anywhere in the message.
If you are not explicitly addressed, always respond with {no_reply_token}.
When answering, don't use LaTeX.
When setting/editing reminder, you are not allowed to answer from your own knowledge.
You must call the appropriate tool and return its output.
"""
DB_BLOB_DIR = Path(os.environ["DB_BLOB_DIR"])
DB_BLOB_DIR.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -128,19 +133,20 @@ async def generate_response(chat_id: int, con: psycopg2.connect) -> str:
messages.append(
{
"role": "assistant" if user_id == 0 else "user",
"content": f"{user_name} ({user_id}): {message}",
"content": json.dumps(
{"user_name": user_name, "user_id": user_id, "chat_id": chat_id, "message": message}
),
}
)
logger.info("all messages: %s", messages)
response = await ask_ai(messages)

response = response.removeprefix(f"{BOT_NAME} ({BOT_USER_ID}): ")
response = json.loads(response).get("message", "no message received.")
cur.execute(
"""
INSERT INTO user_message (chat_id, user_id, message)
VALUES (%s, 0, %s)
VALUES (%s, %s, %s)
""",
(chat_id, response),
(chat_id, BOT_USER_ID, response),
)
con.commit()
return response
Expand Down
32 changes: 25 additions & 7 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
import asyncio
import json
from typing import Any
import os
import openai
from openai import OpenAI
from openai.types.chat import ChatCompletion

from evaluate import evaluate
from logger import logger
from tool_function import evaluate, set_reminder
import logging

logger = logging.getLogger(__name__)

# todo: make this a class

OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
OPENAI_MODEL = "gpt-4o"

TOOL_DEF = json.load(open("tools.json"))

client = OpenAI(api_key=OPENAI_API_KEY)


async def ask_ai(messages: list) -> str:
async def ask_ai(messages: list[Any]) -> str:
loop = asyncio.get_running_loop() # gain access to the scheduler

def runs_in_background_thread() -> ChatCompletion:
Expand All @@ -25,7 +30,7 @@ def runs_in_background_thread() -> ChatCompletion:
completion = client.chat.completions.create(
model=OPENAI_MODEL,
messages=messages,
tools=json.load(open("tools.json")),
tools=TOOL_DEF,
)
except openai.BadRequestError as e:
logger.error(f"OpenAI API error: {e}")
Expand All @@ -40,10 +45,23 @@ def runs_in_background_thread() -> ChatCompletion:

while message.tool_calls:
tool_call = message.tool_calls[0]
logger.info(f"tool call {tool_call}")
function = tool_call.function.name
answer = "no function is called."
if function == "set_reminder":
arguments = json.loads(tool_call.function.arguments)
chat_id, action, duration, deadline = (
arguments["chat_id"],
arguments["action"],
arguments.get("duration", None),
arguments.get("deadline", None),
)
answer = set_reminder(chat_id, action, deadline, duration)
elif function == "evaluate":
arguments = json.loads(tool_call.function.arguments)
expression = arguments["expression"]
answer = evaluate(expression)
logger.info(f"Tool call message: {message}")
arguments = json.loads(tool_call.function.arguments)
expression = arguments["expression"]
answer = evaluate(expression)
function_call_result_message = {
"role": "tool",
"content": json.dumps({"result": answer}),
Expand Down
8 changes: 0 additions & 8 deletions src/logger.py

This file was deleted.

90 changes: 62 additions & 28 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
# Copyright Song Meo <songmeo@pm.me>

import asyncio
import time
from datetime import datetime, timedelta, timezone
from typing import Any
import psycopg2
import os
from db import con
from dotenv import load_dotenv
import telegram
from telegram import Update, error
Expand All @@ -21,18 +20,56 @@
)
from handler import store_message, generate_response, help_command
from handler import BOT_NAME, BOT_USER_ID
from logger import logger
import logging

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

load_dotenv()

DB_USER = os.environ["DB_USER"]
DB_PASSWORD = os.environ["DB_PASSWORD"]
DB_NAME = os.environ["DB_NAME"]
DB_HOST = os.environ["DB_HOST"]
TOKEN = os.environ["TOKEN"]


async def generate_response_loop(con: psycopg2.connect) -> None:
async def send_reminder_loop() -> None:
# TODO FIXME add logging for single iteration failure and total loop failure.
while True:
cur = con.cursor()
cur.execute(
"SELECT chat_id, action, deadline FROM user_reminder WHERE deadline <= %s AND is_notified = %s",
(
datetime.now().isoformat(),
False,
),
)
reminders = cur.fetchall()
for r in reminders:
chat_id, action, deadline = r
bot = telegram.Bot(token=TOKEN)

# This is a simplified solution; in the future, we should ask the LLM to process reminder events.
# In response, the LLM can invoke another tool, like message a specific user (doesn't have to be
# the user that created the reminder), or do this and that.
message = f"This is a reminder to {action} at {deadline}."

cur.execute(
"""
INSERT INTO user_message (chat_id, user_id, message)
VALUES (%s, %s, %s)
""",
(chat_id, BOT_USER_ID, message),
)

await bot.send_message(chat_id=chat_id, text=message)
cur.execute("UPDATE user_reminder SET is_notified = TRUE WHERE action = %s", (action,))
con.commit()

await asyncio.sleep(60)


async def generate_response_loop() -> None:
while True:
cur = con.cursor()
cur.execute("SELECT chat_id FROM user_message")
Expand Down Expand Up @@ -66,25 +103,7 @@ async def generate_response_loop(con: psycopg2.connect) -> None:

def main() -> None:
application = Application.builder().token(TOKEN).build()

for _ in range(5):
try:
con = psycopg2.connect(
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
host=DB_HOST,
port=5432,
)
cur = con.cursor()
logger.info("Connection successful!")
break # success! no need to repeat
except psycopg2.OperationalError as e:
logger.error("Error while connecting to the database:", e)
time.sleep(5)
else:
logger.error("Can't connect to the database. Abort.")
exit(1)
cur = con.cursor()

cur.execute(
"""
Expand Down Expand Up @@ -126,6 +145,20 @@ def main() -> None:
)
"""
)

cur.execute(
"""
CREATE TABLE IF NOT EXISTS user_reminder (
id SERIAL PRIMARY KEY, -- SERIAL handles auto-incrementing
chat_id BIGINT NOT NULL,
action TEXT NOT NULL,
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
deadline TIMESTAMPTZ NOT NULL,
is_notified BOOLEAN DEFAULT FALSE
)
"""
)

con.commit()

async def sticker_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
Expand Down Expand Up @@ -161,7 +194,8 @@ async def message_handler_proxy(update: Update, context: ContextTypes.DEFAULT_TY

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(generate_response_loop(con))
loop.create_task(generate_response_loop())
loop.create_task(send_reminder_loop())

# Run the bot until the user presses Ctrl-C
application.run_polling(allowed_updates=Update.ALL_TYPES)
Expand Down
1 change: 1 addition & 0 deletions src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ h11~=0.14.0
httpcore~=1.0.7
httpx~=0.28.1
idna~=3.10
isodate~=0.7.2
jiter~=0.8.0
openai~=1.57.1
psycopg2-binary~=2.9.7
Expand Down
38 changes: 38 additions & 0 deletions src/tool_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from db import con
from datetime import datetime, timezone
import isodate


def set_reminder(chat_id: int, action: str, deadline: str, duration: str) -> str:
today = datetime.now(timezone.utc)
if deadline:
if (datetime.fromisoformat(deadline) - today).seconds < -60:
return "Sorry deadline is past."
elif duration:
td = isodate.parse_duration(duration)
deadline = today + td
else:
return "You must define deadline or duration."
cur = con.cursor()

cur.execute(
"INSERT INTO user_reminder (chat_id, action, deadline) VALUES (%s, %s, %s)", (chat_id, action, deadline)
)
con.commit()

return f"A reminder for '{action}' is set on {deadline}."


def evaluate(expression: str) -> str:
try:
ans = eval(expression)
except Exception as error:
return str(type(error).__name__)
return str(ans)


def test_evaluate() -> None:
assert evaluate("123 + 456") == str(123 + 456)
assert evaluate("455 +_/ 342") == "NameError"
assert evaluate("455 +_( 342") == "SyntaxError"
assert evaluate("455 / 0") == "ZeroDivisionError"
Loading