You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					97 lines
				
				2.7 KiB
			
		
		
			
		
	
	
					97 lines
				
				2.7 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# !/usr/bin/env python3
							 | 
						||
| 
								 | 
							
								# -*- encoding : utf-8 -*-
							 | 
						||
| 
								 | 
							
								# @Filename    : chat.py
							 | 
						||
| 
								 | 
							
								# @Software    : VSCode
							 | 
						||
| 
								 | 
							
								# @Datetime    : 2023/03/23 11:25:10
							 | 
						||
| 
								 | 
							
								# @Author      : leo liu
							 | 
						||
| 
								 | 
							
								# @Version     : 1.0
							 | 
						||
| 
								 | 
							
								# @Description : CRUD helpers for chat history, per-day question counts and chat-related settings
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from datetime import date
							 | 
						||
| 
								 | 
							
								from sqlalchemy.orm import Session
							 | 
						||
| 
								 | 
							
								from sqlalchemy.sql.expression import and_
							 | 
						||
| 
								 | 
							
								from typing import Any
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from extensions.curd_base import CRUDBase
							 | 
						||
| 
								 | 
							
								from models import ChatHistory, ChatCountDay
							 | 
						||
| 
								 | 
							
								from ..schemas import chat_schema
							 | 
						||
| 
								 | 
							
								from models import User, ChatCountDay, Settings
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class CRUDChat(CRUDBase[ChatHistory, ChatCountDay, chat_schema.Chat]):
								    """CRUD helpers for chat history, per-day question counters and chat settings."""

								    def get_openai_keys(self, db: Session) -> Any:
								        """Return the OpenAI keys stored in the ``Settings`` table.

								        Collects ``value`` from every ``Settings`` row whose ``key`` is
								        ``'openai_key'``.

								        Returns:
								            A list of the stored values; empty if no such rows exist. No
								            ORDER BY is applied, so the order is whatever the database returns.
								        """
								        rows = db.query(Settings).filter(Settings.key == 'openai_key')
								        # A Query object is always truthy, so the old ``if settings:`` guard
								        # never filtered anything; iterating an empty query yields [].
								        return [row.value for row in rows]

								    def get_default_max_tokens(self, db: Session, default: int = 2000,
								                               max_limit: int = 2048) -> int:
								        """Return the configured default ``max_tokens`` value.

								        Reads the ``Settings`` row whose key is ``'default_max_tokens'``.

								        Args:
								            db: Active SQLAlchemy session.
								            default: Value returned when the setting is missing or unusable.
								            max_limit: Inclusive upper bound for an accepted stored value.

								        Returns:
								            The stored value if it parses as an integer in
								            ``(0, max_limit]``; otherwise *default*.
								        """
								        setting = db.query(Settings).filter(Settings.key == 'default_max_tokens').first()
								        if setting is None:
								            return default

								        try:
								            # int() rather than str.isdigit(): isdigit() raised TypeError on a
								            # None value and accepted characters such as '²' that int() then
								            # rejected. str() turns None into 'None', which fails to parse here.
								            limit = int(str(setting.value).strip())
								        except ValueError:
								            return default

								        return limit if 0 < limit <= max_limit else default

								    def _get_count_row(self, db: Session, user: User, q_date: date):
								        """Return *user*'s ``ChatCountDay`` row for *q_date*, or ``None`` if absent."""
								        return db.query(ChatCountDay).filter(
								            and_(ChatCountDay.user_id == user.user_id, ChatCountDay.q_time == q_date)
								        ).first()

								    def get_chat_count(self, db: Session, user: User) -> int:
								        """Return how many questions *user* has asked today.

								        "Today" is ``date.today()``, i.e. the server's local date.

								        Returns:
								            The ``q_times`` counter of today's ``ChatCountDay`` row, or ``-1``
								            when no row exists for today yet (sentinel kept for existing
								            callers).
								        """
								        row = self._get_count_row(db, user, date.today())
								        return row.q_times if row else -1

								    def update_chat_count(self, db: Session, user: User):
								        """Increment *user*'s question counter for today.

								        Creates today's ``ChatCountDay`` row with ``q_times = 1`` when none
								        exists. Otherwise adds one to its ``q_times``. Then commits the session
								        and refreshes *user*.
								        """
								        # Compute the date once so the lookup and any new row agree,
								        # even when the call straddles midnight.
								        q_date = date.today()
								        row = self._get_count_row(db, user, q_date)

								        if row:
								            row.q_times = row.q_times + 1
								        else:
								            user.chat_count_day.append(
								                ChatCountDay(user_id=user.user_id, q_time=q_date, q_times=1)
								            )

								        # NOTE(review): this is a read-modify-write. Concurrent requests for
								        # the same user could lose an increment or insert duplicate rows.
								        # Confirm whether a unique constraint or SQL-side increment is needed.
								        db.commit()
								        db.refresh(user)

								    def new_chat(self, db: Session, user: User, q: str, a: str):
								        """Record one question/answer exchange for *user*.

								        Args:
								            db: Active SQLAlchemy session; committed by this method.
								            user: Owner of the exchange; refreshed after the commit.
								            q: Question text, stored as ``q_content``.
								            a: Answer text, stored as ``a_content``.
								        """
								        chat_history = ChatHistory(
								            user_id=user.user_id,
								            username=user.username,
								            q_content=q,
								            a_content=a,
								        )
								        user.chat_history.append(chat_history)

								        db.commit()
								        db.refresh(user)