You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
2.7 KiB
97 lines
2.7 KiB
# !/usr/bin/env python3
|
|
# -*- encoding : utf-8 -*-
|
|
# @Filename : chat.py
|
|
# @Software : VSCode
|
|
# @Datetime : 2023/03/23 11:25:10
|
|
# @Author : leo liu
|
|
# @Version : 1.0
|
|
# @Description :
|
|
|
|
from datetime import date
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql.expression import and_
|
|
from typing import Any
|
|
|
|
from extensions.curd_base import CRUDBase
|
|
from models import ChatHistory, ChatCountDay
|
|
from ..schemas import chat_schema
|
|
from models import User, ChatCountDay, Settings
|
|
|
|
class CRUDChat(CRUDBase[ChatHistory, ChatCountDay, chat_schema.Chat]):
    """CRUD helpers for chat settings, per-day question counters and chat history."""

    def get_openai_keys(self, db: Session) -> Any:
        """Get the configured OpenAI keys.

        :param db: SQLAlchemy session used for the query.
        :return: list of ``Settings.value`` for every row whose key is
            ``'openai_key'`` (empty list when none exist).
        """
        # The query object itself is not the result set, so testing it for
        # truthiness (as the original did) never filtered anything; iterating
        # an empty result simply yields an empty list.
        settings = db.query(Settings).filter(Settings.key == 'openai_key')
        return [setting.value for setting in settings]

    def get_default_max_tokens(self, db: Session, *, fallback: int = 2000,
                               upper_limit: int = 2048) -> int:
        """Get the default ``max_tokens`` from the ``'default_max_tokens'`` setting.

        :param db: SQLAlchemy session used for the query.
        :param fallback: value returned when the setting is missing, not a
            plain decimal number, or out of range (default 2000).
        :param upper_limit: largest accepted value, inclusive (default 2048).
        :return: the configured value when it lies in ``(0, upper_limit]``,
            otherwise ``fallback``.
        """
        setting = db.query(Settings).filter(Settings.key == 'default_max_tokens').first()
        if setting is None:
            return fallback

        raw = setting.value
        # str.isdigit() also accepts characters such as '²' that int() rejects
        # with ValueError; isdecimal() matches exactly what int() can parse.
        # A NULL / non-string value is treated as "not configured".
        if not isinstance(raw, str) or not raw.isdecimal():
            return fallback

        limit = int(raw)
        return limit if 0 < limit <= upper_limit else fallback

    def _get_count_row(self, db: Session, user: User, q_date: date):
        """Return the ChatCountDay row for ``user`` on ``q_date``, or None."""
        return (
            db.query(ChatCountDay)
            .filter(and_(ChatCountDay.user_id == user.user_id,
                         ChatCountDay.q_time == q_date))
            .first()
        )

    def get_chat_count(self, db: Session, user: User) -> int:
        """Get how many questions ``user`` has asked today.

        :param db: SQLAlchemy session used for the query.
        :param user: the user whose counter is read.
        :return: today's ``q_times`` for the user, or ``-1`` when no counter
            row exists for today (``date.today()``).
        """
        chat_count_day = self._get_count_row(db, user, date.today())
        if chat_count_day is None:
            return -1
        return chat_count_day.q_times

    def update_chat_count(self, db: Session, user: User):
        """Increment today's question counter for ``user``.

        Creates the row with ``q_times = 1`` on the first question of the day,
        otherwise adds one to the existing row. Commits the session and
        refreshes ``user``.
        """
        q_date = date.today()
        chat_count_day = self._get_count_row(db, user, q_date)

        if chat_count_day is None:
            # First question today: attach a new counter row through the
            # user's relationship.
            user.chat_count_day.append(
                ChatCountDay(user_id=user.user_id, q_time=q_date, q_times=1)
            )
        else:
            # NOTE(review): read-then-increment in Python is not atomic;
            # concurrent requests for the same user may lose an increment.
            # Consider an SQL-side increment if that matters.
            chat_count_day.q_times = chat_count_day.q_times + 1

        db.commit()
        db.refresh(user)

    def new_chat(self, db: Session, user: User, q: str, a: str):
        """Record one question/answer exchange for ``user``.

        :param db: SQLAlchemy session to commit.
        :param user: owner of the conversation.
        :param q: question text (stored as ``q_content``).
        :param a: answer text (stored as ``a_content``).

        Commits the session and refreshes ``user``.
        """
        chat_history = ChatHistory(
            user_id=user.user_id,
            username=user.username,
            q_content=q,
            a_content=a,
        )
        user.chat_history.append(chat_history)

        db.commit()
        db.refresh(user)
|