You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

97 lines
2.7 KiB

3 years ago
# !/usr/bin/env python3
# -*- encoding : utf-8 -*-
# @Filename : chat.py
# @Software : VSCode
# @Datetime : 2023/03/23 11:25:10
# @Author : leo liu
# @Version : 1.0
# @Description : CRUD helpers for chat history, per-user daily question counts and chat-related settings
from datetime import date
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_
from typing import Any
from extensions.curd_base import CRUDBase
from models import ChatHistory, ChatCountDay
from ..schemas import chat_schema
from models import User, ChatCountDay, Settings
class CRUDChat(CRUDBase[ChatHistory, ChatCountDay, chat_schema.Chat]):
    """CRUD helpers for chat settings, per-user daily question counts and chat history."""

    # Fallback returned when the 'default_max_tokens' setting is missing or invalid.
    DEFAULT_MAX_TOKENS = 2000
    # Largest 'default_max_tokens' value accepted from the settings table.
    MAX_TOKENS_LIMIT = 2048

    def get_openai_keys(self, db: Session) -> list:
        """Return the values of every ``Settings`` row whose key is ``'openai_key'``.

        :param db: active SQLAlchemy session.
        :return: list of the matching ``Settings.value`` entries, in the order the
            database returns them; empty when no key is configured.
        """
        query = db.query(Settings).filter(Settings.key == 'openai_key')
        return [setting.value for setting in query]

    def get_default_max_tokens(self, db: Session) -> int:
        """Return the configured default ``max_tokens``.

        Reads the ``'default_max_tokens'`` setting and uses it only when it is a
        plain decimal integer in ``(0, MAX_TOKENS_LIMIT]``; in every other case
        (missing row, ``None`` value, non-numeric text, out of range) falls back
        to ``DEFAULT_MAX_TOKENS``.

        :param db: active SQLAlchemy session.
        :return: the effective default ``max_tokens``.
        """
        setting = db.query(Settings).filter(Settings.key == 'default_max_tokens').first()
        if setting is None or setting.value is None:
            return self.DEFAULT_MAX_TOKENS
        raw = str(setting.value).strip()
        # isdecimal rather than isdigit: isdigit accepts characters such as '²'
        # that int() cannot parse, which would raise ValueError here.
        if not raw.isdecimal():
            return self.DEFAULT_MAX_TOKENS
        limit = int(raw)
        if 0 < limit <= self.MAX_TOKENS_LIMIT:
            return limit
        return self.DEFAULT_MAX_TOKENS

    def _get_chat_count_day(self, db: Session, user: User, q_date: date):
        """Return the ``ChatCountDay`` row for *user* on *q_date*, or ``None`` if absent."""
        return db.query(ChatCountDay).filter(
            and_(ChatCountDay.user_id == user.user_id, ChatCountDay.q_time == q_date)
        ).first()

    def get_chat_count(self, db: Session, user: User) -> int:
        """Return how many questions *user* has asked today.

        :param db: active SQLAlchemy session.
        :param user: the user whose count is read.
        :return: today's ``q_times`` value, or ``-1`` when no row exists for today
            (presumably callers use this sentinel to tell "no row yet" apart from 0).
        """
        chat_count_day = self._get_chat_count_day(db, user, date.today())
        if chat_count_day is None:
            return -1
        return chat_count_day.q_times

    def update_chat_count(self, db: Session, user: User) -> None:
        """Increment today's question count for *user*, creating the row on first use.

        Commits the session and refreshes *user* afterwards.

        :param db: active SQLAlchemy session.
        :param user: the user whose count is incremented.
        """
        q_date = date.today()
        chat_count_day = self._get_chat_count_day(db, user, q_date)
        if chat_count_day is None:
            user.chat_count_day.append(
                ChatCountDay(user_id=user.user_id, q_time=q_date, q_times=1)
            )
        else:
            # NOTE(review): read-modify-write in Python; two concurrent requests for
            # the same user could lose an increment. Consider an SQL-side increment
            # if that matters.
            chat_count_day.q_times += 1
        db.commit()
        db.refresh(user)

    def new_chat(self, db: Session, user: User, q: str, a: str) -> None:
        """Store one question/answer exchange in *user*'s chat history.

        Commits the session and refreshes *user* afterwards.

        :param db: active SQLAlchemy session.
        :param user: owner of the history entry (its ``user_id`` and ``username`` are copied).
        :param q: question text, stored as ``q_content``.
        :param a: answer text, stored as ``a_content``.
        """
        user.chat_history.append(
            ChatHistory(
                user_id=user.user_id,
                username=user.username,
                q_content=q,
                a_content=a,
            )
        )
        db.commit()
        db.refresh(user)