@@ -453,6 +453,132 @@ def _completion_generate(self, prompts, stop):
453453 return [answer ["text" ] for answer in result ["choices" ]]
454454
455455
456+ class AzureOpenAiAgent (Agent ):
457+ """
458+ Agent that uses Azure OpenAI to generate code. See the [official
459+ documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
460+ model on Azure
461+
462+ <Tip warning={true}>
463+
464+ The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
465+ `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
466+
467+ </Tip>
468+
469+ Args:
470+ deployment_id (`str`):
471+ The name of the deployed Azure openAI model to use.
472+ api_key (`str`, *optional*):
473+ The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
474+ resource_name (`str`, *optional*):
475+ The name of your Azure OpenAI Resource. If unset, will look for the environment variable
476+ `"AZURE_OPENAI_RESOURCE_NAME"`.
477+ api_version (`str`, *optional*, default to `"2022-12-01"`):
478+ The API version to use for this agent.
479+ is_chat_mode (`bool`, *optional*):
480+ Whether you are using a completion model or a chat model (see note above, chat models won't be as
481+ efficient). Will default to `gpt` being in the `deployment_id` or not.
482+ chat_prompt_template (`str`, *optional*):
483+ Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
484+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
485+ `chat_prompt_template.txt` in this repo in this case.
486+ run_prompt_template (`str`, *optional*):
487+ Pass along your own prompt if you want to override the default template for the `run` method. Can be the
488+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
489+ `run_prompt_template.txt` in this repo in this case.
490+ additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
491+ Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
492+ one of the default tools, that default tool will be overridden.
493+
494+ Example:
495+
496+ ```py
497+ from transformers import AzureOpenAiAgent
498+
499+ agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
500+ agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
501+ ```
502+ """
503+
504+ def __init__ (
505+ self ,
506+ deployment_id ,
507+ api_key = None ,
508+ resource_name = None ,
509+ api_version = "2022-12-01" ,
510+ is_chat_model = None ,
511+ chat_prompt_template = None ,
512+ run_prompt_template = None ,
513+ additional_tools = None ,
514+ ):
515+ if not is_openai_available ():
516+ raise ImportError ("Using `OpenAiAgent` requires `openai`: `pip install openai`." )
517+
518+ self .deployment_id = deployment_id
519+ openai .api_type = "azure"
520+ if api_key is None :
521+ api_key = os .environ .get ("AZURE_OPENAI_API_KEY" , None )
522+ if api_key is None :
523+ raise ValueError (
524+ "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
525+ "`os.environ['AZURE_OPENAI_API_KEY'] = xxx."
526+ )
527+ else :
528+ openai .api_key = api_key
529+ if resource_name is None :
530+ resource_name = os .environ .get ("AZURE_OPENAI_RESOURCE_NAME" , None )
531+ if resource_name is None :
532+ raise ValueError (
533+ "You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
534+ "`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx."
535+ )
536+ else :
537+ openai .api_base = f"https://{ resource_name } .openai.azure.com"
538+ openai .api_version = api_version
539+
540+ if is_chat_model is None :
541+ is_chat_model = "gpt" in deployment_id .lower ()
542+ self .is_chat_model = is_chat_model
543+
544+ super ().__init__ (
545+ chat_prompt_template = chat_prompt_template ,
546+ run_prompt_template = run_prompt_template ,
547+ additional_tools = additional_tools ,
548+ )
549+
550+ def generate_many (self , prompts , stop ):
551+ if self .is_chat_model :
552+ return [self ._chat_generate (prompt , stop ) for prompt in prompts ]
553+ else :
554+ return self ._completion_generate (prompts , stop )
555+
556+ def generate_one (self , prompt , stop ):
557+ if self .is_chat_model :
558+ return self ._chat_generate (prompt , stop )
559+ else :
560+ return self ._completion_generate ([prompt ], stop )[0 ]
561+
562+ def _chat_generate (self , prompt , stop ):
563+ result = openai .ChatCompletion .create (
564+ engine = self .deployment_id ,
565+ messages = [{"role" : "user" , "content" : prompt }],
566+ temperature = 0 ,
567+ stop = stop ,
568+ )
569+ return result ["choices" ][0 ]["message" ]["content" ]
570+
571+ def _completion_generate (self , prompts , stop ):
572+ result = openai .Completion .create (
573+ engine = self .deployment_id ,
574+ prompt = prompts ,
575+ temperature = 0 ,
576+ stop = stop ,
577+ max_tokens = 200 ,
578+ )
579+ return [answer ["text" ] for answer in result ["choices" ]]
580+
581+
456582class HfAgent (Agent ):
457583 """
458584 Agent that uses an inference endpoint to generate code.
0 commit comments