-
-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* LLamaSharp integration with tests * changed test project name. added projects to solution Closes #32 --------- Co-authored-by: Konstantin S <[email protected]>
- Loading branch information
Showing
9 changed files
with
349 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
src/libs/Providers/LangChain.Providers.LLamaSharp/ELLamaSharpModelMode.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
namespace LangChain.Providers.LLamaSharp; | ||
|
||
public enum ELLamaSharpModelMode | ||
{ | ||
Instruction, | ||
Chat | ||
} |
30 changes: 30 additions & 0 deletions
30
src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpConfiguration.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
namespace LangChain.Providers.LLamaSharp; | ||
|
||
public class LLamaSharpConfiguration | ||
{ | ||
/// <summary> | ||
/// Path to *.bin file | ||
/// </summary> | ||
public string PathToModelFile { get; set; } | ||
|
||
/// <summary> | ||
/// Model mode. | ||
/// Chat - for conversation completion | ||
/// Instruction - for instruction execution | ||
/// </summary> | ||
public ELLamaSharpModelMode Mode { get; set; } = ELLamaSharpModelMode.Chat; | ||
|
||
/// <summary> | ||
/// Context size | ||
/// How much tokens model will remember. | ||
/// Usually 2048 for llama | ||
/// </summary> | ||
public int ContextSize { get; set; } = 512; | ||
|
||
/// <summary> | ||
/// Temperature | ||
/// The level of model's creativity | ||
/// </summary> | ||
public float Temperature { get; set; } = 0.7f; | ||
|
||
} |
145 changes: 145 additions & 0 deletions
145
src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModel.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
using System.Diagnostics; | ||
using LLama; | ||
using System.Reflection; | ||
|
||
namespace LangChain.Providers.LLamaSharp | ||
{ | ||
public class LLamaSharpModel : IChatModel | ||
{ | ||
private readonly LLamaSharpConfiguration _configuration; | ||
private readonly LLamaModel _model; | ||
public string Id { get; } | ||
public Usage TotalUsage { get; private set; } | ||
public int ContextLength =>_configuration.ContextSize; | ||
|
||
public LLamaSharpModel(LLamaSharpConfiguration configuration) | ||
{ | ||
_configuration = configuration; | ||
Id = Path.GetFileNameWithoutExtension(configuration.PathToModelFile); | ||
|
||
_model = new LLamaModel( | ||
new LLamaParams(model: configuration.PathToModelFile, | ||
interactive: true, | ||
instruct: configuration.Mode==ELLamaSharpModelMode.Instruction, | ||
temp: configuration.Temperature, | ||
n_ctx: configuration.ContextSize, repeat_penalty: 1.0f)); | ||
} | ||
|
||
string ConvertRole(MessageRole role) | ||
{ | ||
return role switch | ||
{ | ||
MessageRole.Human => "Human: ", | ||
MessageRole.Ai => "Assistant: ", | ||
MessageRole.System => "", | ||
_ => throw new NotSupportedException($"the role {role} is not supported") | ||
}; | ||
} | ||
|
||
string ConvertMessage(Message message) | ||
{ | ||
return $"{ConvertRole(message.Role)}{message.Content}"; | ||
} | ||
|
||
|
||
|
||
|
||
|
||
public Task<ChatResponse> GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default) | ||
{ | ||
// take all messages except the last one | ||
// and make them a prompt | ||
var messagesArray = request.Messages.ToArray(); | ||
var prePromptMessages = messagesArray.Take(messagesArray.Length-1).ToList(); | ||
|
||
|
||
|
||
// use the last message as input | ||
string input; | ||
if(prePromptMessages.Count>0) | ||
input = request.Messages.Last().Content; | ||
else | ||
input = ConvertMessage(request.Messages.Last()); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
|
||
IEnumerable<string> response; | ||
|
||
|
||
var session = CreateSession(prePromptMessages); | ||
response = session.Chat(input + "\n"); | ||
string buf = ""; | ||
foreach (var message in response) | ||
{ | ||
buf += message; | ||
if (_configuration.Mode == ELLamaSharpModelMode.Instruction) | ||
{ | ||
if (buf.EndsWith("###")) | ||
{ | ||
buf = buf.Substring(0, buf.Length - 3); | ||
break; | ||
} | ||
|
||
} | ||
} | ||
|
||
var output = SanitizeOutput(buf); | ||
|
||
var result = request.Messages.ToList(); | ||
|
||
switch (_configuration.Mode) | ||
{ | ||
case ELLamaSharpModelMode.Chat: | ||
result.Add(output.AsAiMessage()); | ||
break; | ||
case ELLamaSharpModelMode.Instruction: | ||
result.Add(output.AsSystemMessage()); | ||
break; | ||
} | ||
|
||
|
||
watch.Stop(); | ||
|
||
// Unsupported | ||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
TotalUsage += usage; | ||
|
||
return Task.FromResult(new ChatResponse( | ||
Messages: result, | ||
Usage: usage)); | ||
} | ||
|
||
private static string SanitizeOutput(string output) | ||
{ | ||
output = output.Replace("\nHuman:", ""); | ||
output = output.Replace("Assistant:", ""); | ||
output = output.Trim(); | ||
return output; | ||
} | ||
|
||
private ChatSession<LLamaModel> CreateSession(List<Message> preprompt) | ||
{ | ||
var res = new ChatSession<LLamaModel>(_model); | ||
|
||
if (_configuration.Mode == ELLamaSharpModelMode.Chat) | ||
res = res.WithAntiprompt(new[] { "Human:" }); | ||
|
||
if (preprompt.Count > 0) | ||
{ | ||
preprompt.Add("".AsHumanMessage()); | ||
|
||
var prompt = string.Join( | ||
"\n", preprompt.Select(ConvertMessage).ToArray()); | ||
|
||
res = res | ||
.WithPrompt(prompt); | ||
} | ||
|
||
|
||
return res; | ||
} | ||
} | ||
} |
27 changes: 27 additions & 0 deletions
27
src/libs/Providers/LangChain.Providers.LLamaSharp/LangChain.Providers.LLamaSharp.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFrameworks>net4.6.2;netstandard2.0;net6.0;net7.0</TargetFrameworks> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="LLamaSharp" /> | ||
<PackageReference Include="System.Net.Http" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup Label="Usings"> | ||
<Using Include="System.Net.Http" /> | ||
</ItemGroup> | ||
|
||
|
||
<PropertyGroup Label="NuGet"> | ||
<Description>LLamaSharp Chat model provider.</Description> | ||
<PackageTags>$(PackageTags);LLama;LLamaSharp;api</PackageTags> | ||
</PropertyGroup> | ||
|
||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\LangChain.Core\LangChain.Core.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
55 changes: 55 additions & 0 deletions
55
src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
using LangChain.Providers; | ||
using LangChain.Providers.LLamaSharp; | ||
|
||
namespace LangChain.Providers.LLamaSharp.IntegrationTests; | ||
|
||
[TestClass] | ||
public class LLamaSharpTests | ||
{ | ||
[TestMethod] | ||
#if CONTINUOUS_INTEGRATION_BUILD | ||
[Ignore] | ||
#endif | ||
public void PrepromptTest() | ||
{ | ||
var model = new LLamaSharpModel(new LLamaSharpConfiguration | ||
{ | ||
PathToModelFile = Path.Combine(Environment.ExpandEnvironmentVariables("%LLAMA_MODELS%"), "ggml-model-f32-q4_0.bin"), | ||
}); | ||
|
||
var response=model.GenerateAsync(new ChatRequest(new List<Message> | ||
{ | ||
"You are simple assistant. If human say 'Bob' then you will respond with 'Jack'.".AsSystemMessage(), | ||
"Bob".AsHumanMessage(), | ||
"Jack".AsAiMessage(), | ||
"Bob".AsHumanMessage(), | ||
"Jack".AsAiMessage(), | ||
"Bob".AsHumanMessage(), | ||
})).Result; | ||
|
||
Assert.AreEqual(response.Messages.Last().Content, "Jack"); | ||
|
||
} | ||
|
||
[TestMethod] | ||
#if CONTINUOUS_INTEGRATION_BUILD | ||
[Ignore] | ||
#endif | ||
public void InstructionTest() | ||
{ | ||
var model = new LLamaSharpModel(new LLamaSharpConfiguration | ||
{ | ||
PathToModelFile = Path.Combine(Environment.ExpandEnvironmentVariables("%LLAMA_MODELS%"), "ggml-model-f32-q4_0.bin"), | ||
Mode = ELLamaSharpModelMode.Instruction | ||
}); | ||
|
||
var response=model.GenerateAsync(new ChatRequest(new List<Message> | ||
{ | ||
"You are a calculator. You will be provided with expression. You must calculate it and print the result. Do not add any addition information.".AsSystemMessage(), | ||
"2 + 2".AsSystemMessage(), | ||
})).Result; | ||
|
||
Assert.IsTrue(response.Messages.Last().Content.Trim().Equals("4")); | ||
|
||
} | ||
} |
23 changes: 23 additions & 0 deletions
23
...viders.LLamaSharp.IntegrationTests/LangChain.Providers.LLamaSharp.IntegrationTests.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFramework>net7.0</TargetFramework> | ||
<ImplicitUsings>enable</ImplicitUsings> | ||
<Nullable>enable</Nullable> | ||
|
||
<IsPackable>false</IsPackable> | ||
<IsTestProject>true</IsTestProject> | ||
<PlatformTarget>x64</PlatformTarget> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="LLamaSharp.Backend.Cpu" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\libs\Providers\LangChain.Providers.LLamaSharp\LangChain.Providers.LLamaSharp.csproj" /> | ||
</ItemGroup> | ||
|
||
|
||
|
||
</Project> |