Skip to content

Commit

Permalink
feat: LLamaSharp integration (#44)
Browse files Browse the repository at this point in the history
* LLamaSharp integration with tests
* changed test project name. added projects to solution

Closes #32 
---------

Co-authored-by: Konstantin S <[email protected]>
  • Loading branch information
TesAnti and HavenDV authored Nov 3, 2023
1 parent 4f0c7e6 commit c776a41
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 41 deletions.
7 changes: 4 additions & 3 deletions LangChain.Sources.slnf
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
"solution": {
"path": "LangChain.sln",
"projects": [
"src\\libs\\Sources\\LangChain.Sources.Abstractions\\LangChain.Sources.Abstractions.csproj",
"src\\libs\\LangChain.Core\\LangChain.Core.csproj",
"src\\tests\\LangChain.UnitTest\\LangChain.UnitTest.csproj",
"src\\libs\\Providers\\LangChain.Providers.Abstractions\\LangChain.Providers.Abstractions.csproj"
"src\\libs\\Providers\\LangChain.Providers.LLamaSharp\\LangChain.Providers.LLamaSharp.csproj",
"src\\libs\\Providers\\LangChain.Providers.HuggingFace\\LangChain.Providers.HuggingFace.csproj",
"src\\tests\\LangChain.IntegrationTests.LLamaSharp\\LangChain.IntegrationTests.LLamaSharp.csproj",
"src\\tests\\LangChain.UnitTest\\LangChain.UnitTest.csproj"
]
}
}
93 changes: 55 additions & 38 deletions LangChain.sln

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
<PackageVersion Include="H.Resources.Generator" Version="1.5.1" />
<PackageVersion Include="HuggingFace" Version="0.2.4" />
<PackageVersion Include="LeonardoAi" Version="0.1.0" />
<PackageVersion Include="LLamaSharp" Version="0.3.0" />
<PackageVersion Include="LLamaSharp.Backend.Cpu" Version="0.3.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.7.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.7.2" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.AI.OpenAI" Version="0.15.230531.5-preview" />
Expand All @@ -32,6 +34,7 @@
<PackageVersion Include="MSTest.TestFramework" Version="3.1.1" />
<PackageVersion Include="PdfPig" Version="0.1.9-alpha-20231029-17d50" />
<PackageVersion Include="PolySharp" Version="1.13.2" />
<PackageVersion Include="System.Net.Http" Version="4.3.4" />
<PackageVersion Include="System.Text.Json" Version="7.0.3" />
<PackageVersion Include="Tiktoken" Version="1.1.3" />
<PackageVersion Include="tryAGI.OpenAI" Version="1.8.2" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace LangChain.Providers.LLamaSharp;

public enum ELLamaSharpModelMode
{
Instruction,
Chat
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
namespace LangChain.Providers.LLamaSharp;

public class LLamaSharpConfiguration
{
/// <summary>
/// Path to *.bin file
/// </summary>
public string PathToModelFile { get; set; }

/// <summary>
/// Model mode.
/// Chat - for conversation completion
/// Instruction - for instruction execution
/// </summary>
public ELLamaSharpModelMode Mode { get; set; } = ELLamaSharpModelMode.Chat;

/// <summary>
/// Context size
/// How much tokens model will remember.
/// Usually 2048 for llama
/// </summary>
public int ContextSize { get; set; } = 512;

/// <summary>
/// Temperature
/// The level of model's creativity
/// </summary>
public float Temperature { get; set; } = 0.7f;

}
145 changes: 145 additions & 0 deletions src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
using System.Diagnostics;
using LLama;
using System.Reflection;

namespace LangChain.Providers.LLamaSharp
{
public class LLamaSharpModel : IChatModel
{
private readonly LLamaSharpConfiguration _configuration;
private readonly LLamaModel _model;
public string Id { get; }
public Usage TotalUsage { get; private set; }
public int ContextLength =>_configuration.ContextSize;

public LLamaSharpModel(LLamaSharpConfiguration configuration)
{
_configuration = configuration;
Id = Path.GetFileNameWithoutExtension(configuration.PathToModelFile);

_model = new LLamaModel(
new LLamaParams(model: configuration.PathToModelFile,
interactive: true,
instruct: configuration.Mode==ELLamaSharpModelMode.Instruction,
temp: configuration.Temperature,
n_ctx: configuration.ContextSize, repeat_penalty: 1.0f));
}

string ConvertRole(MessageRole role)
{
return role switch
{
MessageRole.Human => "Human: ",
MessageRole.Ai => "Assistant: ",
MessageRole.System => "",
_ => throw new NotSupportedException($"the role {role} is not supported")
};
}

string ConvertMessage(Message message)
{
return $"{ConvertRole(message.Role)}{message.Content}";
}





public Task<ChatResponse> GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default)
{
// take all messages except the last one
// and make them a prompt
var messagesArray = request.Messages.ToArray();
var prePromptMessages = messagesArray.Take(messagesArray.Length-1).ToList();



// use the last message as input
string input;
if(prePromptMessages.Count>0)
input = request.Messages.Last().Content;
else
input = ConvertMessage(request.Messages.Last());

var watch = Stopwatch.StartNew();

IEnumerable<string> response;


var session = CreateSession(prePromptMessages);
response = session.Chat(input + "\n");
string buf = "";
foreach (var message in response)
{
buf += message;
if (_configuration.Mode == ELLamaSharpModelMode.Instruction)
{
if (buf.EndsWith("###"))
{
buf = buf.Substring(0, buf.Length - 3);
break;
}

}
}

var output = SanitizeOutput(buf);

var result = request.Messages.ToList();

switch (_configuration.Mode)
{
case ELLamaSharpModelMode.Chat:
result.Add(output.AsAiMessage());
break;
case ELLamaSharpModelMode.Instruction:
result.Add(output.AsSystemMessage());
break;
}


watch.Stop();

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
TotalUsage += usage;

return Task.FromResult(new ChatResponse(
Messages: result,
Usage: usage));
}

private static string SanitizeOutput(string output)
{
output = output.Replace("\nHuman:", "");
output = output.Replace("Assistant:", "");
output = output.Trim();
return output;
}

private ChatSession<LLamaModel> CreateSession(List<Message> preprompt)
{
var res = new ChatSession<LLamaModel>(_model);

if (_configuration.Mode == ELLamaSharpModelMode.Chat)
res = res.WithAntiprompt(new[] { "Human:" });

if (preprompt.Count > 0)
{
preprompt.Add("".AsHumanMessage());

var prompt = string.Join(
"\n", preprompt.Select(ConvertMessage).ToArray());

res = res
.WithPrompt(prompt);
}


return res;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net4.6.2;netstandard2.0;net6.0;net7.0</TargetFrameworks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="LLamaSharp" />
<PackageReference Include="System.Net.Http" />
</ItemGroup>

<ItemGroup Label="Usings">
<Using Include="System.Net.Http" />
</ItemGroup>


<PropertyGroup Label="NuGet">
<Description>LLamaSharp Chat model provider.</Description>
<PackageTags>$(PackageTags);LLama;LLamaSharp;api</PackageTags>
</PropertyGroup>


<ItemGroup>
<ProjectReference Include="..\..\LangChain.Core\LangChain.Core.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using LangChain.Providers;
using LangChain.Providers.LLamaSharp;

namespace LangChain.Providers.LLamaSharp.IntegrationTests;

[TestClass]
public class LLamaSharpTests
{
[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void PrepromptTest()
{
var model = new LLamaSharpModel(new LLamaSharpConfiguration
{
PathToModelFile = Path.Combine(Environment.ExpandEnvironmentVariables("%LLAMA_MODELS%"), "ggml-model-f32-q4_0.bin"),
});

var response=model.GenerateAsync(new ChatRequest(new List<Message>
{
"You are simple assistant. If human say 'Bob' then you will respond with 'Jack'.".AsSystemMessage(),
"Bob".AsHumanMessage(),
"Jack".AsAiMessage(),
"Bob".AsHumanMessage(),
"Jack".AsAiMessage(),
"Bob".AsHumanMessage(),
})).Result;

Assert.AreEqual(response.Messages.Last().Content, "Jack");

}

[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void InstructionTest()
{
var model = new LLamaSharpModel(new LLamaSharpConfiguration
{
PathToModelFile = Path.Combine(Environment.ExpandEnvironmentVariables("%LLAMA_MODELS%"), "ggml-model-f32-q4_0.bin"),
Mode = ELLamaSharpModelMode.Instruction
});

var response=model.GenerateAsync(new ChatRequest(new List<Message>
{
"You are a calculator. You will be provided with expression. You must calculate it and print the result. Do not add any addition information.".AsSystemMessage(),
"2 + 2".AsSystemMessage(),
})).Result;

Assert.IsTrue(response.Messages.Last().Content.Trim().Equals("4"));

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>

<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="LLamaSharp.Backend.Cpu" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\libs\Providers\LangChain.Providers.LLamaSharp\LangChain.Providers.LLamaSharp.csproj" />
</ItemGroup>



</Project>

0 comments on commit c776a41

Please sign in to comment.