|
|
#!/usr/bin/env python3
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# @Filename : user.py
|
|
|
# @Software : VSCode
|
|
|
# @Datetime : 2021/11/04 21:25:44
|
|
|
# @Author : leo liu
|
|
|
# @Version : 1.0
|
|
|
# @Description : CRUD helpers for the User model (listing, lookup, authentication, create/update).
|
|
|
|
|
|
from datetime import datetime
|
|
|
from typing import Optional
|
|
|
from loguru import logger
|
|
|
from pydantic.types import conint
|
|
|
from sqlalchemy import func, desc
|
|
|
from sqlalchemy.orm import Session
|
|
|
from sqlalchemy.sql.expression import and_, or_, join
|
|
|
|
|
|
from core.security import get_password_hash, verify_password
|
|
|
from extensions.curd_base import CRUDBase
|
|
|
from models import User, Role, LoginHistory
|
|
|
|
|
|
from .role import crud_role
|
|
|
from ..schemas import user_schema
|
|
|
|
|
|
class CRUDUser(CRUDBase[User, user_schema.UserCreate, user_schema.UserUpdate]):
    """CRUD helpers for the ``User`` model: listing, lookup, authentication, create and update."""

    # Attributes serialised, in this order, for every row returned by ``query_all``.
    _LIST_FIELDS = (
        "id",
        "user_id",
        "username",
        "nickname",
        "email",
        "avatar",
        "phone",
        "gender",
        "register_time",
        "last_login_time",
        "last_login_ip",
        "wechat_openid",
        "country",
        "province",
        "city",
        "roles",
        "is_delete",
        "update_time",
    )

    @staticmethod
    def _filter_users(query, search, role: str):
        """Apply ``search`` to ``query``, joining ``User.roles`` only when ``role`` is non-empty.

        An unconditional join on ``User.roles`` yields one row per (user, role) pair,
        which inflates counts and shortens pages for users holding several roles.
        """
        if role:
            query = query.join(User.roles).filter(Role.role_id == role)
        return query.filter(search)

    @staticmethod
    def _resolve_roles(db: Session, role_ids) -> list:
        """Look up each id via ``crud_role.get_by_id``, skipping ids for which it returns a falsy value."""
        roles = []
        for role_id in role_ids:
            role = crud_role.get_by_id(db, role_id=role_id)
            if role:
                roles.append(role)
        return roles

    @staticmethod
    def query_all(db: Session, *, col_val: str = "", is_delete: int = 0, role: str = "", page: int = 1,
                  page_size: conint(le=50) = 20, order_by: str = "id", is_desc: bool = False) -> dict:
        """
        Query a (optionally paginated) list of users.

        :param db: database session.
        :param col_val: fuzzy search text, matched with SQL ``LIKE '%col_val%'`` against
            username, nickname and email; empty disables the text filter.
        :param is_delete: only users whose ``is_delete`` equals this value are returned.
        :param role: when non-empty, only users holding the role whose ``role_id`` equals it.
        :param page: 1-based page number; ``page <= 0`` returns every matching user.
        :param page_size: rows per page. Note the ``conint(le=50)`` annotation is not
            enforced by a plain function call; callers must validate it.
        :param order_by: attribute name to sort by, resolved with ``User.getAttrFromName``.
        :param is_desc: sort descending when True.
        :return: ``{"items": [<dict per user>], "total": <number of distinct matching users>}``.
        """
        # Base filter on the soft-delete flag, plus multi-column fuzzy search when requested.
        search = User.is_delete == is_delete
        if col_val:
            pattern = '%' + col_val + "%"
            search = and_(search, or_(User.username.like(pattern),
                                      User.nickname.like(pattern),
                                      User.email.like(pattern)))

        # Resolve the sort column from its attribute name.
        sort_column = User.getAttrFromName(User, order_by)
        order_type = desc(sort_column) if is_desc else sort_column

        # Count distinct users so a join can never inflate the total.
        count_query = db.query(func.count(User.user_id.distinct())).select_from(User)
        total = CRUDUser._filter_users(count_query, search, role).scalar()

        users_query = CRUDUser._filter_users(db.query(User), search, role).order_by(order_type)
        if page > 0:  # page > 0: paginated query; otherwise return everything
            users_query = users_query.offset((page - 1) * page_size).limit(page_size)

        items = [
            {field: getattr(obj, field) for field in CRUDUser._LIST_FIELDS}
            for obj in users_query.all()
        ]

        return {
            "items": items,
            "total": total
        }

    @staticmethod
    def get_by_email(db: Session, *, email: str) -> Optional[User]:
        """
        Return the first user whose email equals ``email``, or None.

        ``email`` is keyword-only (note the ``*`` in the signature):
            correct:   curd_user.get_by_email(db, email="xxx")
            incorrect: curd_user.get_by_email(db, "xxx")

        :param db: database session.
        :param email: email address to match exactly.
        :return: the matching ``User`` or None.
        """
        return db.query(User).filter(User.email == email).first()

    @staticmethod
    def get_by_username(db: Session, *, username: str) -> Optional[User]:
        """Return the first user whose username equals ``username``, or None."""
        return db.query(User).filter(User.username == username).first()

    def authenticate(self, db: Session, *, username: str, password: str, ip: str) -> Optional[User]:
        """
        Verify a username/password pair and record the login.

        On success, updates ``last_login_time``/``last_login_ip``, appends a
        ``LoginHistory`` row, commits, and returns the refreshed user.

        :param username: username to look up.
        :param password: plain-text password checked with ``verify_password``.
        :param ip: client address stored as ``last_login_ip`` and ``login_ipv4``.
        :return: the authenticated ``User``, or None if the user is unknown or the password is wrong.
        """
        user = self.get_by_username(db, username=username)
        if not user:
            return None
        if not verify_password(password, user.hashed_password):
            return None

        user.last_login_time = datetime.now()
        user.last_login_ip = ip

        login_history = LoginHistory(
            user_id=user.user_id,
            username=user.username,
            login_time=user.last_login_time,
            login_ipv4=user.last_login_ip
        )
        user.login_histories.append(login_history)

        db.commit()
        db.refresh(user)
        return user

    def authenticate_by_email(self, db: Session, *, email: str, password: str) -> Optional[User]:
        """
        Verify an email/password pair.

        Unlike ``authenticate``, this does not update login fields or record login history.

        :return: the matching ``User``, or None if the user is unknown or the password is wrong.
        """
        user = self.get_by_email(db, email=email)
        if not user:
            return None
        if not verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    def is_active(user: User) -> bool:
        """Return True when ``user.is_active`` equals 1."""
        return user.is_active == 1

    @staticmethod
    def is_delete(user: User) -> bool:
        """Return True when ``user.is_delete`` equals 1."""
        return user.is_delete == 1

    @staticmethod
    def is_exist(db: Session, user_in: user_schema.UserCreate) -> bool:
        """Return True when a user with ``user_in.username`` already exists."""
        return db.query(User).filter(User.username == user_in.username).first() is not None

    @staticmethod
    def create(db: Session, *, obj_in: user_schema.UserCreate) -> User:
        """
        Create a user from ``obj_in``, hashing the password and attaching its roles.

        Role ids in ``obj_in.roles`` that ``crud_role.get_by_id`` cannot resolve are skipped.

        :return: the committed and refreshed ``User``.
        """
        db_user = User(
            username=obj_in.username,
            nickname=obj_in.nickname,
            email=obj_in.email,
            hashed_password=get_password_hash(obj_in.password),
            q_limit_day=obj_in.q_limit_day,
            tokens_limit=obj_in.tokens_limit,
            is_active=obj_in.is_active
        )
        db_user.roles.extend(CRUDUser._resolve_roles(db, obj_in.roles))

        db.add(db_user)
        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def update(db: Session, *, obj_in: user_schema.UserUpdate) -> Optional[User]:
        """
        Update the user identified by ``obj_in.username``.

        The password is re-hashed only when ``obj_in.password`` is truthy. Quota fields
        are updated only when provided and non-negative. The role set is replaced by
        the resolvable ids in ``obj_in.roles``.

        :return: the committed and refreshed ``User``, or None if no such user exists.
        """
        db_obj = db.query(User).filter(User.username == obj_in.username).first()
        if db_obj is None:
            return None

        db_obj.nickname = obj_in.nickname
        db_obj.email = obj_in.email
        if obj_in.password:
            db_obj.hashed_password = get_password_hash(obj_in.password)

        # ``create`` stores the daily quota on the model column ``q_limit_day``. The old code
        # assigned ``db_obj.q_day_limit``, which does not match that column, so the value was
        # never saved. Accept either spelling from the update schema (TODO confirm against UserUpdate).
        q_limit_day = getattr(obj_in, "q_limit_day", None)
        if q_limit_day is None:
            q_limit_day = getattr(obj_in, "q_day_limit", None)
        if q_limit_day is not None and q_limit_day >= 0:
            db_obj.q_limit_day = q_limit_day

        tokens_limit = obj_in.tokens_limit
        if tokens_limit is not None and tokens_limit >= 0:
            db_obj.tokens_limit = tokens_limit

        db_obj.roles.clear()
        db_obj.roles.extend(CRUDUser._resolve_roles(db, obj_in.roles))

        db_obj.is_active = obj_in.is_active
        db_obj.is_delete = obj_in.is_delete

        db.commit()
        db.refresh(db_obj)
        return db_obj