You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							91 lines
						
					
					
						
							2.3 KiB
						
					
					
				
			
		
		
	
	
							91 lines
						
					
					
						
							2.3 KiB
						
					
					
				import io
 | 
						|
import json
 | 
						|
 | 
						|
import pytest
 | 
						|
from aiohttp import ClientSession
 | 
						|
 | 
						|
import openai
 | 
						|
from openai import error
 | 
						|
 | 
						|
# Apply the asyncio marker (normally provided by the pytest-asyncio plugin)
# to every test in this module so the coroutine tests are awaited.
pytestmark = [
    pytest.mark.asyncio,
]
 | 
						|
 | 
						|
 | 
						|
# FILE TESTS
 | 
						|
async def test_file_upload():
    """Upload a small fine-tune file, then read it back by id.

    The create response must echo the purpose and carry an ``id``; fetching
    that id must report the file with status ``"uploaded"``.
    """
    payload = json.dumps({"prompt": "test file data", "completion": "tada"})
    uploaded = await openai.File.acreate(
        file=io.StringIO(payload),
        purpose="fine-tune",
    )
    assert uploaded.purpose == "fine-tune"
    assert "id" in uploaded

    fetched = await openai.File.aretrieve(id=uploaded.id)
    assert fetched.status == "uploaded"
 | 
						|
 | 
						|
 | 
						|
# COMPLETION TESTS
 | 
						|
async def test_completions():
    """Requesting n=5 completions for a single prompt yields five choices."""
    completion = await openai.Completion.acreate(
        engine="ada",
        prompt="This was a test",
        n=5,
    )
    assert len(completion.choices) == 5
 | 
						|
 | 
						|
 | 
						|
async def test_completions_multiple_prompts():
    """``n`` choices are produced per prompt: 2 prompts x n=5 -> 10 choices."""
    prompts = ["This was a test", "This was another test"]
    per_prompt = 5
    completion = await openai.Completion.acreate(
        engine="ada", n=per_prompt, prompt=prompts
    )
    # 2 * 5 == 10, the same count the original hard-coded.
    assert len(completion.choices) == len(prompts) * per_prompt
 | 
						|
 | 
						|
 | 
						|
async def test_completions_model():
    """Passing ``model=`` instead of ``engine=`` also works.

    The response must contain five choices and report a model name
    that starts with ``"ada"``.
    """
    response = await openai.Completion.acreate(
        model="ada",
        prompt="This was a test",
        n=5,
    )
    assert len(response.choices) == 5
    assert response.model.startswith("ada")
 | 
						|
 | 
						|
 | 
						|
async def test_timeout_raises_error():
    """A 0.01 s ``request_timeout`` on a heavy request raises ``error.Timeout``."""
    # A long prompt with many choices should take far longer than 10 ms.
    slow_request = dict(
        prompt="test" * 1000,
        n=10,
        model="ada",
        max_tokens=100,
        request_timeout=0.01,
    )
    with pytest.raises(error.Timeout):
        await openai.Completion.acreate(**slow_request)
 | 
						|
 | 
						|
 | 
						|
async def test_timeout_does_not_error():
    """A tiny prompt completes well within a generous 10 s ``request_timeout``.

    The test passes as long as the call returns without raising.
    """
    await openai.Completion.acreate(
        request_timeout=10,
        model="ada",
        prompt="test",
    )
 | 
						|
 | 
						|
 | 
						|
async def test_completions_stream_finishes_global_session():
    """Streaming completes when the aiohttp session comes from ``openai.aiosession``.

    The test installs its own ``ClientSession`` via ``openai.aiosession.set``
    and streams a short completion through it.  It expects more than one
    streamed part.

    The previously installed value is restored in ``finally``.  Otherwise
    ``openai.aiosession`` would keep pointing at a session that the
    surrounding ``async with`` has already closed.
    """
    async with ClientSession() as session:
        # set() returns a token that lets us undo exactly this assignment.
        token = openai.aiosession.set(session)
        try:
            # A query that should be fast
            parts = []
            async for part in await openai.Completion.acreate(
                prompt="test", model="ada", request_timeout=3, stream=True
            ):
                parts.append(part)
            assert len(parts) > 1
        finally:
            # Restore whatever session (or default) was active before this test,
            # even if the request or the assertion failed.
            openai.aiosession.reset(token)
 | 
						|
 | 
						|
 | 
						|
async def test_completions_stream_finishes_local_session():
    """Streaming completes without a session installed via ``openai.aiosession``.

    The client is left to obtain its own HTTP session.  More than one
    streamed part must arrive.
    """
    # A query that should be fast
    stream = await openai.Completion.acreate(
        prompt="test", model="ada", request_timeout=3, stream=True
    )
    chunks = [chunk async for chunk in stream]
    assert len(chunks) > 1
 |