You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

91 lines
2.3 KiB

import io
import json
import pytest
from aiohttp import ClientSession
import openai
from openai import error
# Apply the pytest-asyncio marker to every test in this module so each
# ``async def`` test is awaited on an event loop.
pytestmark = pytest.mark.asyncio
# FILE TESTS
async def test_file_upload():
    """Upload a one-record fine-tune file, then retrieve it again by id."""
    record = json.dumps({"prompt": "test file data", "completion": "tada"})
    uploaded = await openai.File.acreate(
        file=io.StringIO(record),
        purpose="fine-tune",
    )
    assert uploaded.purpose == "fine-tune"
    assert "id" in uploaded

    # Fetch the same file back by id and check its reported status.
    fetched = await openai.File.aretrieve(id=uploaded.id)
    assert fetched.status == "uploaded"
# COMPLETION TESTS
async def test_completions():
    """Asking for ``n=5`` completions of a single prompt yields five choices."""
    completion = await openai.Completion.acreate(
        prompt="This was a test",
        n=5,
        engine="ada",
    )
    assert len(completion.choices) == 5
async def test_completions_multiple_prompts():
    """A batch of prompts returns ``n`` choices for each prompt."""
    prompts = ["This was a test", "This was another test"]
    per_prompt = 5
    completion = await openai.Completion.acreate(
        prompt=prompts, n=per_prompt, engine="ada"
    )
    # Two prompts with n=5 each, so 10 choices in total.
    assert len(completion.choices) == len(prompts) * per_prompt
async def test_completions_model():
    """Selecting by ``model`` (instead of ``engine``) works and is reported back."""
    completion = await openai.Completion.acreate(
        prompt="This was a test",
        n=5,
        model="ada",
    )
    assert len(completion.choices) == 5
    # Only the prefix is checked; presumably the returned id may carry a suffix.
    assert completion.model.startswith("ada")
async def test_timeout_raises_error():
    """A request given a tiny ``request_timeout`` raises ``error.Timeout``."""
    long_prompt = "test" * 1000
    # Ten samples of up to 100 tokens over a long prompt should take far
    # longer than the 10 ms allowed below.
    with pytest.raises(error.Timeout):
        await openai.Completion.acreate(
            prompt=long_prompt,
            n=10,
            model="ada",
            max_tokens=100,
            request_timeout=0.01,
        )
async def test_timeout_does_not_error():
    """A short request completes without raising under a generous timeout."""
    # One tiny prompt should come back quickly; 10 s leaves ample headroom.
    generous_timeout = 10
    await openai.Completion.acreate(
        prompt="test",
        model="ada",
        request_timeout=generous_timeout,
    )
async def test_completions_stream_finishes_global_session():
    """Stream a completion through a session installed in ``openai.aiosession``.

    The shared ``ClientSession`` is registered with ``openai.aiosession.set``
    for the duration of the request and the previous value is restored
    afterwards.  Without the restore, the session (closed once the
    ``async with`` exits) would stay installed for any later request made
    in the same context.
    """
    async with ClientSession() as session:
        token = openai.aiosession.set(session)
        try:
            # A query that should be fast
            stream = await openai.Completion.acreate(
                prompt="test", model="ada", request_timeout=3, stream=True
            )
            parts = [part async for part in stream]
        finally:
            # Undo the set() so a closed session is never left behind.
            openai.aiosession.reset(token)
    assert len(parts) > 1
async def test_completions_stream_finishes_local_session():
    """Stream a completion without setting ``openai.aiosession`` first."""
    # A query that should be fast
    stream = await openai.Completion.acreate(
        prompt="test", model="ada", request_timeout=3, stream=True
    )
    parts = [part async for part in stream]
    assert len(parts) > 1