You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
5.3 KiB
189 lines
5.3 KiB
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
import openai
|
|
|
|
# Console echo level taken from the environment at import time. Only the
# values "debug" and "info" have an effect (see _console_log_level).
OPENAI_LOG = os.getenv("OPENAI_LOG")

# Logger named "openai"; the log_* helpers in this module write to it.
logger = logging.getLogger(name="openai")
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["log_info", "log_debug", "log_warn", "logfmt"]
|
|
|
|
def api_key_to_header(api, key):
    """Build the authentication header for ``key`` on the given API type.

    ``ApiType.OPEN_AI`` and ``ApiType.AZURE_AD`` use an
    ``Authorization: Bearer <key>`` header; any other value (e.g.
    ``ApiType.AZURE``) uses an ``api-key`` header.

    :param api: the API type, compared against ``ApiType`` members.
    :param key: the API key or token; it is interpolated via an f-string.
    :returns: a single-entry header dict.
    """
    # ApiType is defined further down this module; the name is only resolved
    # when this function is called, so definition order is not a problem.
    if api in (ApiType.OPEN_AI, ApiType.AZURE_AD):
        return {"Authorization": f"Bearer {key}"}
    return {"api-key": f"{key}"}
|
|
|
|
|
|
class ApiType(Enum):
    """API backend types; ``api_key_to_header`` uses them to pick the auth header."""

    AZURE = 1
    OPEN_AI = 2
    AZURE_AD = 3

    @staticmethod
    def from_str(label):
        """Parse a case-insensitive API type name into an ``ApiType`` member.

        Accepted spellings: ``"azure"``; ``"azure_ad"`` / ``"azuread"``;
        ``"open_ai"`` / ``"openai"``.

        :param label: the API type name (must support ``.lower()``).
        :returns: the matching ``ApiType`` member.
        :raises openai.error.InvalidAPIType: if ``label`` matches none of them.
        """
        normalized = label.lower()
        if normalized == "azure":
            return ApiType.AZURE
        if normalized in ("azure_ad", "azuread"):
            return ApiType.AZURE_AD
        if normalized in ("open_ai", "openai"):
            return ApiType.OPEN_AI
        raise openai.error.InvalidAPIType(
            "The API type provided is invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
        )
|
|
|
|
|
|
def _console_log_level():
    """Return the level at which log lines are echoed to stderr, or ``None``.

    ``openai.log`` takes precedence over the ``OPENAI_LOG`` environment value.
    Only ``"debug"`` and ``"info"`` are recognised; any other value is ignored.
    """
    console_levels = ("debug", "info")
    if openai.log in console_levels:
        return openai.log
    if OPENAI_LOG in console_levels:
        return OPENAI_LOG
    return None
|
|
|
|
|
|
def log_debug(message, **params):
    """Log ``message`` plus ``params`` as one logfmt line at DEBUG level.

    The line is printed to stderr when the console log level is ``"debug"``
    and forwarded to ``logger.debug`` when that logger is enabled for DEBUG.
    When neither destination would emit it, formatting is skipped entirely:
    debug payloads may be large, and ``logfmt`` should not run for output
    nobody sees.
    """
    to_console = _console_log_level() == "debug"
    if not to_console and not logger.isEnabledFor(logging.DEBUG):
        return
    msg = logfmt(dict(message=message, **params))
    if to_console:
        print(msg, file=sys.stderr)
    logger.debug(msg)
|
|
|
|
|
|
def log_info(message, **params):
    """Log ``message`` plus ``params`` as one logfmt line at INFO level.

    The line is also printed to stderr when the console log level is
    ``"debug"`` or ``"info"``.
    """
    line = logfmt({"message": message, **params})
    echo_to_console = _console_log_level() in ["debug", "info"]
    if echo_to_console:
        print(line, file=sys.stderr)
    logger.info(line)
|
|
|
|
|
|
def log_warn(message, **params):
    """Log ``message`` plus ``params`` as one logfmt line at WARNING level.

    Warnings are always printed to stderr, regardless of the console log level.
    """
    msg = logfmt(dict(message=message, **params))
    print(msg, file=sys.stderr)
    # Logger.warn is a deprecated alias; Logger.warning is the supported API.
    logger.warning(msg)
|
|
|
|
|
|
def logfmt(props):
    """Render ``props`` as a single logfmt-style line.

    Pairs are emitted as ``key=value``, sorted by key and joined with single
    spaces. A key or value containing whitespace is wrapped with ``repr()`` so
    the line stays splittable on spaces.

    :param props: mapping of field names to values.
    :returns: the formatted line (``""`` for an empty mapping).
    """

    def fmt(key, val):
        # bytes / bytearray values are decoded for readability. Undecodable
        # bytes are replaced with U+FFFD instead of raising, so that logging
        # an arbitrary binary payload cannot break the caller.
        if hasattr(val, "decode"):
            val = val.decode("utf-8", errors="replace")
        # Leave str values untouched (no re-encoding); stringify everything else.
        if not isinstance(val, str):
            val = str(val)
        if re.search(r"\s", val):
            val = repr(val)
        # Keys normally arrive as str (from **kwargs); coerce anything else
        # rather than letting re.search raise TypeError.
        if not isinstance(key, str):
            key = str(key)
        if re.search(r"\s", key):
            key = repr(key)
        return f"{key}={val}"

    return " ".join(fmt(key, val) for key, val in sorted(props.items()))
|
|
|
|
|
|
def get_object_classes():
    """Return the ``OBJECT_CLASSES`` mapping from ``openai.object_classes``.

    The import is deferred to call time to avoid a circular dependency at
    module import time.
    """
    from openai import object_classes as _object_classes

    return _object_classes.OBJECT_CLASSES
|
|
|
|
|
|
def convert_to_openai_object(
    resp,
    api_key=None,
    api_version=None,
    organization=None,
    engine=None,
    plain_old_data=False,
):
    """Convert an API response payload into OpenAI objects.

    * An ``OpenAIResponse`` is unwrapped first: its ``data`` becomes the
      payload, and its ``organization`` and ``response_ms`` are used for the
      constructed object (the wrapper's organization overrides the argument).
    * With ``plain_old_data=True`` the (unwrapped) payload is returned as-is.
    * Lists are converted element by element.
    * Dicts that are not already ``OpenAIObject`` instances are copied and
      built via ``construct_from`` on the class registered for their
      ``"object"`` field, falling back to ``OpenAIObject``.
    * Anything else is returned unchanged.
    """
    response_ms: Optional[int] = None
    if isinstance(resp, openai.openai_response.OpenAIResponse):
        organization = resp.organization
        response_ms = resp.response_ms
        resp = resp.data

    if plain_old_data:
        return resp

    if isinstance(resp, list):
        # plain_old_data is not forwarded; it is always False on this path.
        return [
            convert_to_openai_object(
                item, api_key, api_version, organization, engine=engine
            )
            for item in resp
        ]

    if not isinstance(resp, dict):
        return resp

    base_class = openai.openai_object.OpenAIObject
    if isinstance(resp, base_class):
        # Already converted; hand it back untouched.
        return resp

    payload = resp.copy()
    type_name = payload.get("object")
    klass = (
        get_object_classes().get(type_name, base_class)
        if isinstance(type_name, str)
        else base_class
    )
    return klass.construct_from(
        payload,
        api_key=api_key,
        api_version=api_version,
        organization=organization,
        response_ms=response_ms,
        engine=engine,
    )
|
|
|
|
|
|
def convert_to_dict(obj):
    """Recursively convert an OpenAIObject (or any dict/list tree) to builtins.

    Dicts, including dict subclasses such as OpenAIObject, become plain
    ``dict`` instances. Lists become new ``list`` instances. Their contents
    are converted recursively. Any other value (tuples included) is returned
    unchanged.

    :param obj: the object to convert.
    :returns: the converted structure.
    """
    if isinstance(obj, list):
        return list(map(convert_to_dict, obj))
    if not isinstance(obj, dict):
        return obj
    # OpenAIObjects are dict subclasses, so rebuilding from .items() yields a
    # plain dict.
    plain = {}
    for key, value in obj.items():
        plain[key] = convert_to_dict(value)
    return plain
|
|
|
|
|
|
def merge_dicts(x, y):
    """Return a new mapping holding ``x``'s entries overridden by ``y``'s.

    Neither argument is modified: the result is ``x.copy()`` updated in place
    with ``y``.
    """
    combined = x.copy()
    combined.update(y)
    return combined
|
|
|
|
|
|
def default_api_key() -> str:
    """Resolve the API key to use when none is passed explicitly.

    Resolution order:

    1. If ``openai.api_key_path`` is set, read the key from that file
       (surrounding whitespace stripped). It must start with ``"sk-"``.
    2. Otherwise return ``openai.api_key`` if it is not ``None``.

    :returns: the API key string.
    :raises ValueError: if the key read from ``openai.api_key_path`` does not
        start with ``"sk-"``.
    :raises openai.error.AuthenticationError: if neither source provides a key.
    """
    if openai.api_key_path:
        # Explicit encoding instead of the locale-dependent default. "utf-8-sig"
        # also drops a leading BOM (as some Windows editors write), which would
        # otherwise make a valid key fail the "sk-" check below.
        with open(openai.api_key_path, "rt", encoding="utf-8-sig") as key_file:
            api_key = key_file.read().strip()
        if not api_key.startswith("sk-"):
            raise ValueError(f"Malformed API key in {openai.api_key_path}.")
        return api_key
    if openai.api_key is not None:
        return openai.api_key
    raise openai.error.AuthenticationError(
        "No API key provided. You can set your API key in code using 'openai.api_key = <API-KEY>', or you can set the environment variable OPENAI_API_KEY=<API-KEY>. If your API key is stored in a file, you can point the openai module at it with 'openai.api_key_path = <PATH>'. You can generate API keys in the OpenAI web interface. See https://onboard.openai.com for details, or email support@openai.com if you have any questions."
    )
|