You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

696 lines
22 KiB

3 years ago
import asyncio
import json
import platform
import sys
import threading
import warnings
from contextlib import asynccontextmanager
from json import JSONDecodeError
from typing import (
AsyncGenerator,
AsyncIterator,
Dict,
Iterator,
Optional,
Tuple,
Union,
overload,
)
from urllib.parse import urlencode, urlsplit, urlunsplit
import aiohttp
import requests
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
import openai
from openai import error, util, version
from openai.openai_response import OpenAIResponse
from openai.util import ApiType
# Default request timeout (seconds) used when the caller supplies no
# `request_timeout`.
TIMEOUT_SECS = 600
# Passed as `max_retries` to the HTTPAdapter mounted on each requests session.
MAX_CONNECTION_RETRIES = 2
# Has one attribute per thread, 'session'.
_thread_context = threading.local()
def _build_api_url(url, query):
scheme, netloc, path, base_query, fragment = urlsplit(url)
if base_query:
query = "%s&%s" % (base_query, query)
return urlunsplit((scheme, netloc, path, query, fragment))
def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
"""Returns a value suitable for the 'proxies' argument to 'requests.request."""
if proxy is None:
return None
elif isinstance(proxy, str):
return {"http": proxy, "https": proxy}
elif isinstance(proxy, dict):
return proxy.copy()
else:
raise ValueError(
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
)
def _aiohttp_proxies_arg(proxy) -> Optional[str]:
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
if proxy is None:
return None
elif isinstance(proxy, str):
return proxy
elif isinstance(proxy, dict):
return proxy["https"] if "https" in proxy else proxy["http"]
else:
raise ValueError(
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
)
def _make_session() -> requests.Session:
    """Create a requests session configured from the module-level settings.

    Warns when ``openai.verify_ssl_certs`` is falsy, applies ``openai.proxy``
    when it yields a non-empty mapping, and mounts a retrying adapter for
    ``https://`` URLs.
    """
    if not openai.verify_ssl_certs:
        warnings.warn("verify_ssl_certs is ignored; openai always verifies.")

    session = requests.Session()

    configured_proxies = _requests_proxies_arg(openai.proxy)
    if configured_proxies:
        session.proxies = configured_proxies

    retrying_adapter = requests.adapters.HTTPAdapter(
        max_retries=MAX_CONNECTION_RETRIES
    )
    session.mount("https://", retrying_adapter)
    return session
def parse_stream_helper(line: bytes) -> Optional[str]:
    """Extract the payload of one ``data: `` line of a streamed response.

    Returns:
        The UTF-8 decoded text following the ``data: `` prefix, or ``None``
        for an empty line, the ``data: [DONE]`` sentinel, or any line that
        does not start with the prefix.
    """
    prefix = b"data: "
    if not line:
        return None
    if line.strip() == b"data: [DONE]":
        # return here will cause GeneratorExit exception in urllib3
        # and it will close http connection with TCP Reset
        return None
    if not line.startswith(prefix):
        return None
    return line[len(prefix):].decode("utf-8")
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    """Lazily yield the decoded payload of every data line in *rbody*.

    Lines for which ``parse_stream_helper`` returns ``None`` are skipped.
    """
    for payload in map(parse_stream_helper, rbody):
        if payload is None:
            continue
        yield payload
async def parse_stream_async(rbody: aiohttp.StreamReader):
    """Asynchronously yield the decoded payload of every data line in *rbody*.

    Lines for which ``parse_stream_helper`` returns ``None`` are skipped.
    """
    async for raw_line in rbody:
        payload = parse_stream_helper(raw_line)
        if payload is None:
            continue
        yield payload
class APIRequestor:
    """Build, send, and interpret HTTP requests to the OpenAI API.

    Settings passed to the constructor take precedence; any that are omitted
    fall back to the module-level ``openai`` configuration.
    """

    def __init__(
        self,
        key=None,
        api_base=None,
        api_type=None,
        api_version=None,
        organization=None,
    ):
        """Resolve connection settings, falling back to module-level defaults."""
        self.api_base = api_base or openai.api_base
        self.api_key = key or util.default_api_key()
        self.api_type = (
            ApiType.from_str(api_type)
            if api_type
            else ApiType.from_str(openai.api_type)
        )
        self.api_version = api_version or openai.api_version
        self.organization = organization or openai.organization

    @classmethod
    def format_app_info(cls, info):
        """Format an app-info mapping as ``name[/version][ (url)]``.

        *info* must provide the keys ``name``, ``version`` and ``url``; falsy
        ``version`` / ``url`` values are omitted from the result.
        """
        # Local deliberately not named `str`, which would shadow the builtin.
        formatted = info["name"]
        if info["version"]:
            formatted += "/%s" % (info["version"],)
        if info["url"]:
            formatted += " (%s)" % (info["url"],)
        return formatted

    # Typing-only overloads: the element type of the returned tuple depends on
    # whether `stream` is literally True, literally False, or an arbitrary bool.
    @overload
    def request(
        self,
        method,
        url,
        params,
        headers,
        files,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
        pass

    @overload
    def request(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        *,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
        pass

    @overload
    def request(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: Literal[False] = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[OpenAIResponse, bool, str]:
        pass

    @overload
    def request(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: bool = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
        pass

    def request(
        self,
        method,
        url,
        params=None,
        headers=None,
        files=None,
        stream: bool = False,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
        """Send a synchronous request and interpret the response.

        Returns:
            ``(response, got_stream, api_key)``. ``response`` is an iterator
            of ``OpenAIResponse`` when ``got_stream`` is true, otherwise a
            single ``OpenAIResponse``.
        """
        result = self.request_raw(
            method.lower(),
            url,
            params=params,
            supplied_headers=headers,
            files=files,
            stream=stream,
            request_id=request_id,
            request_timeout=request_timeout,
        )
        resp, got_stream = self._interpret_response(result, stream)
        return resp, got_stream, self.api_key

    # Typing-only overloads mirroring `request`, with async generators for
    # the streaming case.
    @overload
    async def arequest(
        self,
        method,
        url,
        params,
        headers,
        files,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        *,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: Literal[False] = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[OpenAIResponse, bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: bool = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        pass

    async def arequest(
        self,
        method,
        url,
        params=None,
        headers=None,
        files=None,
        stream: bool = False,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        """Send an asynchronous request and interpret the response.

        The session context is entered manually because, for a streamed
        response, it has to stay open until the returned generator finishes.

        Returns:
            ``(response, got_stream, api_key)``. ``response`` is an async
            generator of ``OpenAIResponse`` when ``got_stream`` is true,
            otherwise a single ``OpenAIResponse``.
        """
        ctx = aiohttp_session()
        session = await ctx.__aenter__()
        result = None
        try:
            result = await self.arequest_raw(
                method.lower(),
                url,
                session,
                params=params,
                supplied_headers=headers,
                files=files,
                request_id=request_id,
                request_timeout=request_timeout,
            )
            resp, got_stream = await self._interpret_async_response(result, stream)
        except BaseException:
            # BaseException rather than Exception: asyncio.CancelledError is
            # not an Exception on Python 3.8+, and a cancelled request must
            # still give back its response and session.
            if result is not None:
                result.release()
            await ctx.__aexit__(None, None, None)
            raise

        if got_stream:

            async def wrap_resp():
                assert isinstance(resp, AsyncGenerator)
                try:
                    async for r in resp:
                        yield r
                finally:
                    # Release the response before leaving the session context,
                    # including when the stream is abandoned part-way.
                    result.release()
                    await ctx.__aexit__(None, None, None)

            return wrap_resp(), got_stream, self.api_key
        else:
            # Body has already been read; release the response, then the session.
            result.release()
            await ctx.__aexit__(None, None, None)
            return resp, got_stream, self.api_key

    def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
        """Translate an error payload into the matching ``openai.error`` exception.

        The exception is *returned* (the caller raises it), except when the
        payload itself is malformed, in which case ``error.APIError`` is
        raised from here.

        Args:
            rbody: Raw response body.
            rcode: HTTP status code.
            resp: Decoded response data, expected to hold an ``"error"`` entry.
            rheaders: Response headers.
            stream_error: Whether the error arrived inside a streamed response.

        Raises:
            error.APIError: If *resp* has no usable ``"error"`` entry.
        """
        try:
            error_data = resp["error"]
        except (KeyError, TypeError):
            raise error.APIError(
                "Invalid response object from API: %r (HTTP response code "
                "was %d)" % (rbody, rcode),
                rbody,
                rcode,
                resp,
            )

        if isinstance(error_data, str):
            # A bare string is treated as the message itself so the `.get`
            # lookups below do not fail with AttributeError.
            error_data = {"message": error_data}
        elif not isinstance(error_data, dict):
            raise error.APIError(
                "Invalid response object from API: %r (HTTP response code "
                "was %d)" % (rbody, rcode),
                rbody,
                rcode,
                resp,
            )

        if "internal_message" in error_data:
            message = error_data.get("message")
            internal_message = error_data["internal_message"]
            # Tolerate a missing/None message instead of failing on `+=`.
            error_data["message"] = (
                internal_message
                if message is None
                else message + "\n\n" + internal_message
            )

        util.log_info(
            "OpenAI API error received",
            error_code=error_data.get("code"),
            error_type=error_data.get("type"),
            error_message=error_data.get("message"),
            error_param=error_data.get("param"),
            stream_error=stream_error,
        )

        # Rate limits were previously coded as 400's with code 'rate_limit'
        if rcode == 429:
            return error.RateLimitError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif rcode in [400, 404, 415]:
            return error.InvalidRequestError(
                error_data.get("message"),
                error_data.get("param"),
                error_data.get("code"),
                rbody,
                rcode,
                resp,
                rheaders,
            )
        elif rcode == 401:
            return error.AuthenticationError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif rcode == 403:
            return error.PermissionError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif rcode == 409:
            return error.TryAgain(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif stream_error:
            # TODO: we will soon attach status codes to stream errors
            parts = [error_data.get("message"), "(Error occurred while streaming.)"]
            message = " ".join([p for p in parts if p is not None])
            return error.APIError(message, rbody, rcode, resp, rheaders)
        else:
            return error.APIError(
                f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
                rbody,
                rcode,
                resp,
                rheaders,
            )

    def request_headers(
        self, method: str, extra, request_id: Optional[str]
    ) -> Dict[str, str]:
        """Build the full header set for a request.

        Entries in *extra* are applied last and therefore override the
        generated headers. *method* is accepted but not used here.
        """
        user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
        if openai.app_info:
            user_agent += " " + self.format_app_info(openai.app_info)

        # The "node" field of uname is deliberately left out of what is sent.
        uname_without_node = " ".join(
            v for k, v in platform.uname()._asdict().items() if k != "node"
        )
        ua = {
            "bindings_version": version.VERSION,
            "httplib": "requests",
            "lang": "python",
            "lang_version": platform.python_version(),
            "platform": platform.platform(),
            "publisher": "openai",
            "uname": uname_without_node,
        }
        if openai.app_info:
            ua["application"] = openai.app_info

        headers = {
            "X-OpenAI-Client-User-Agent": json.dumps(ua),
            "User-Agent": user_agent,
        }

        headers.update(util.api_key_to_header(self.api_type, self.api_key))

        if self.organization:
            headers["OpenAI-Organization"] = self.organization

        if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
            headers["OpenAI-Version"] = self.api_version
        if request_id is not None:
            headers["X-Request-Id"] = request_id
        if openai.debug:
            headers["OpenAI-Debug"] = "true"
        headers.update(extra)

        return headers

    def _validate_headers(
        self, supplied_headers: Optional[Dict[str, str]]
    ) -> Dict[str, str]:
        """Return a copy of *supplied_headers* after type-checking it.

        Raises:
            TypeError: If *supplied_headers* is not a dict, or any key or
                value is not a string.
        """
        headers: Dict[str, str] = {}
        if supplied_headers is None:
            return headers

        if not isinstance(supplied_headers, dict):
            raise TypeError("Headers must be a dictionary")

        for k, v in supplied_headers.items():
            if not isinstance(k, str):
                raise TypeError("Header keys must be strings")
            if not isinstance(v, str):
                raise TypeError("Header values must be strings")
            headers[k] = v

        # NOTE: It is possible to do more validation of the headers, but a request could always
        # be made to the API manually with invalid headers, so we need to handle them server side.

        return headers

    def _prepare_request_raw(
        self,
        url,
        supplied_headers,
        method,
        params,
        files,
        request_id: Optional[str],
    ) -> Tuple[str, Dict[str, str], Optional[bytes]]:
        """Compute the absolute URL, headers, and body for a request.

        For ``get``/``delete`` the non-``None`` params are URL-encoded into
        the query string. For ``post``/``put`` the params become the body:
        passed through as-is alongside *files*, JSON-encoded otherwise.

        Raises:
            error.APIConnectionError: If *method* is not one of ``get``,
                ``delete``, ``post`` or ``put`` (expected in lower case).
        """
        abs_url = "%s%s" % (self.api_base, url)
        headers = self._validate_headers(supplied_headers)

        data = None
        if method == "get" or method == "delete":
            if params:
                encoded_params = urlencode(
                    [(k, v) for k, v in params.items() if v is not None]
                )
                abs_url = _build_api_url(abs_url, encoded_params)
        elif method in {"post", "put"}:
            if params and files:
                data = params
            if params and not files:
                data = json.dumps(params).encode()
                headers["Content-Type"] = "application/json"
        else:
            raise error.APIConnectionError(
                "Unrecognized HTTP method %r. This may indicate a bug in the "
                "OpenAI bindings. Please contact support@openai.com for "
                "assistance." % (method,)
            )

        headers = self.request_headers(method, headers, request_id)

        util.log_debug("Request to OpenAI API", method=method, path=abs_url)
        util.log_debug("Post details", data=data, api_version=self.api_version)

        return abs_url, headers, data

    def request_raw(
        self,
        method,
        url,
        *,
        params=None,
        supplied_headers: Optional[Dict[str, str]] = None,
        files=None,
        stream: bool = False,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> requests.Response:
        """Send the request on this thread's requests session.

        A session is created on first use and stored per thread.

        Raises:
            error.Timeout: If the request times out.
            error.APIConnectionError: On any other requests-level failure.
        """
        abs_url, headers, data = self._prepare_request_raw(
            url, supplied_headers, method, params, files, request_id
        )

        if not hasattr(_thread_context, "session"):
            _thread_context.session = _make_session()
        try:
            result = _thread_context.session.request(
                method,
                abs_url,
                headers=headers,
                data=data,
                files=files,
                stream=stream,
                timeout=request_timeout if request_timeout else TIMEOUT_SECS,
            )
        except requests.exceptions.Timeout as e:
            raise error.Timeout("Request timed out: {}".format(e)) from e
        except requests.exceptions.RequestException as e:
            raise error.APIConnectionError(
                "Error communicating with OpenAI: {}".format(e)
            ) from e
        util.log_debug(
            "OpenAI API response",
            path=abs_url,
            response_code=result.status_code,
            processing_ms=result.headers.get("OpenAI-Processing-Ms"),
            request_id=result.headers.get("X-Request-Id"),
        )
        # Don't read the whole stream for debug logging unless necessary.
        if openai.log == "debug":
            util.log_debug(
                "API response body", body=result.content, headers=result.headers
            )
        return result

    async def arequest_raw(
        self,
        method,
        url,
        session,
        *,
        params=None,
        supplied_headers: Optional[Dict[str, str]] = None,
        files=None,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> aiohttp.ClientResponse:
        """Send the request on the given aiohttp *session*.

        A tuple *request_timeout* is mapped to ``(connect, total)``; a scalar
        is used as the total timeout, defaulting to ``TIMEOUT_SECS``.

        Raises:
            error.Timeout: If the request times out.
            error.APIConnectionError: On any other aiohttp client error.
        """
        abs_url, headers, data = self._prepare_request_raw(
            url, supplied_headers, method, params, files, request_id
        )

        if isinstance(request_timeout, tuple):
            timeout = aiohttp.ClientTimeout(
                connect=request_timeout[0],
                total=request_timeout[1],
            )
        else:
            timeout = aiohttp.ClientTimeout(
                total=request_timeout if request_timeout else TIMEOUT_SECS
            )

        if files:
            # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
            # For now we use the private `requests` method that is known to have worked so far.
            data, content_type = requests.models.RequestEncodingMixin._encode_files(  # type: ignore
                files, data
            )
            headers["Content-Type"] = content_type
        request_kwargs = {
            "method": method,
            "url": abs_url,
            "headers": headers,
            "data": data,
            "proxy": _aiohttp_proxies_arg(openai.proxy),
            "timeout": timeout,
        }
        try:
            result = await session.request(**request_kwargs)
            util.log_info(
                "OpenAI API response",
                path=abs_url,
                response_code=result.status,
                processing_ms=result.headers.get("OpenAI-Processing-Ms"),
                request_id=result.headers.get("X-Request-Id"),
            )
            # Don't read the whole stream for debug logging unless necessary.
            if openai.log == "debug":
                util.log_debug(
                    "API response body", body=result.content, headers=result.headers
                )
            return result
        except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
            raise error.Timeout("Request timed out") from e
        except aiohttp.ClientError as e:
            raise error.APIConnectionError("Error communicating with OpenAI") from e

    def _interpret_response(
        self, result: requests.Response, stream: bool
    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
        """Returns the response(s) and a bool indicating whether it is a stream."""
        # Only treat the response as a stream when the caller asked for one
        # AND the server actually replied with an event stream.
        if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
            return (
                self._interpret_response_line(
                    line, result.status_code, result.headers, stream=True
                )
                for line in parse_stream(result.iter_lines())
            ), True
        else:
            return (
                self._interpret_response_line(
                    result.content.decode("utf-8"),
                    result.status_code,
                    result.headers,
                    stream=False,
                ),
                False,
            )

    async def _interpret_async_response(
        self, result: aiohttp.ClientResponse, stream: bool
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
        """Returns the response(s) and a bool indicating whether it is a stream."""
        if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
            return (
                self._interpret_response_line(
                    line, result.status, result.headers, stream=True
                )
                async for line in parse_stream_async(result.content)
            ), True
        else:
            try:
                await result.read()
            except aiohttp.ClientError as e:
                util.log_warn(e, body=result.content)
            # NOTE(review): read() is awaited a second time here; presumably
            # aiohttp serves the already-read body from its cache - confirm.
            return (
                self._interpret_response_line(
                    (await result.read()).decode("utf-8"),
                    result.status,
                    result.headers,
                    stream=False,
                ),
                False,
            )

    def _interpret_response_line(
        self, rbody: str, rcode: int, rheaders, stream: bool
    ) -> OpenAIResponse:
        """Decode one response body (or one stream line) into an ``OpenAIResponse``.

        Raises:
            error.ServiceUnavailableError: On HTTP 503.
            error.APIError: If the body is not valid JSON (and is not
                ``text/plain``).
            Exception: Whatever ``handle_error_response`` produces for a
                non-2xx status or an ``"error"`` entry in a stream line.
        """
        # HTTP 204 response code does not have any content in the body.
        if rcode == 204:
            return OpenAIResponse(None, rheaders)

        if rcode == 503:
            raise error.ServiceUnavailableError(
                "The server is overloaded or not ready yet.",
                rbody,
                rcode,
                headers=rheaders,
            )
        try:
            # `or ""` covers a missing Content-Type header, which would
            # otherwise make the `in` test raise TypeError on None.
            if "text/plain" in (rheaders.get("Content-Type") or ""):
                data = rbody
            else:
                data = json.loads(rbody)
        except (JSONDecodeError, UnicodeDecodeError) as e:
            raise error.APIError(
                f"HTTP code {rcode} from API ({rbody})", rbody, rcode, headers=rheaders
            ) from e
        resp = OpenAIResponse(data, rheaders)
        # In the future, we might add a "status" parameter to errors
        # to better handle the "error while streaming" case.
        stream_error = stream and "error" in resp.data
        if stream_error or not 200 <= rcode < 300:
            raise self.handle_error_response(
                rbody, rcode, resp.data, rheaders, stream_error=stream_error
            )
        return resp
@asynccontextmanager
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
    """Yield the aiohttp session to use for a request.

    A truthy value from ``openai.aiosession.get()`` is yielded as-is and is
    not closed here. Otherwise a new ``ClientSession`` is created and closed
    when the context exits.
    """
    shared_session = openai.aiosession.get()
    if not shared_session:
        async with aiohttp.ClientSession() as owned_session:
            yield owned_session
        return
    yield shared_session