unit RAG.Database;

interface

uses
  System.SysUtils, System.Classes, UniProvider, PostgreSQLUniProvider, Data.DB,
  DBAccess, Uni, MemDS, System.Rtti,

  DAScript, UniScript;

type
  // One chunk of a document as exchanged with the data module.
  TDocumentChunk = record
    // Label stored in / read from the 'tag' column.
    Tag: string;
    // Chunk text stored in / read from the 'document' column.
    Content: string;
    // Embedding vector of the chunk. SaveChunks serializes it as a
    // '[v1,v2,...]' literal; GetRelevantDocsFromDB does not fill it in.
    Embedding: TArray<Extended>;
  end;

  // Data module that owns the database connection and the queries used to
  // store and retrieve document chunks together with their embeddings.
  TMainDataModule = class(TDataModule)
    // Components below are expected to be defined in the associated .dfm
    // (their SQL text is not visible in this unit).
    MainConnection: TUniConnection;
    PostgreSQLUniProvider1: TPostgreSQLUniProvider;
    // Similarity query; uses the params provider, model, vector, domain_id
    // and the macros table_name, limit.
    QryGetEmbeddings: TUniQuery;
    dsDocuments: TUniQuery;
    // Referenced only from code that is currently commented out in AddDocument.
    dsInsertDocument: TUniQuery;
    // Insert statement for one chunk; uses the params tag, document,
    // embedding, provider, model, document_id and the macro table_name.
    dsInsertChunk: TUniQuery;
    // Script executed by CheckEmbeddingTable; takes the macro vector_size.
    ScriptEmbedding: TUniScript;
  private
    FProvider: string;
    FModelId: string;
    // Name of the chunk table: 'rag_data_' followed by the vector size.
    function GetTableName: string;
    // Formats the values with invariant format settings, comma separated.
    function FloatArrayToStr(const AValues: TArray<Extended>): string;
    // Returns nextval() of the given sequence.
    function GetNextVal(const AGenerator: string): Integer;
  public
    // Runs QryGetEmbeddings for the given query vector and domain and returns
    // Tag and Content of every row; ANumberOfChunks is passed as the 'limit' macro.
    function GetRelevantDocsFromDB(const AQueryEmbedding: TArray<Extended>; ADomain:Integer; ANumberOfChunks: Integer): TArray<TDocumentChunk>;
    // Currently only returns the next value of the 'seq_documenti' sequence;
    // AFileName and ASummary are not used by the active code.
    function AddDocument(const AFileName, ASummary: string): Integer;
    // Deletes the row of DOCUMENTI whose DOC_ID equals AId.
    procedure DeleteDocument(AId: Integer);
    // Executes dsInsertChunk once per chunk for the given document id.
    // AFileName is not used by the implementation.
    procedure SaveChunks(const AFileName: string; const AChunks: TArray<TDocumentChunk>; ADocumentId: Integer);
    // Deletes all rows of the chunk table with the given DOCUMENT_ID.
    procedure DeleteChunks(ADocumentId: Integer);
    // Sets DOMAIN_ID on all rows of the chunk table with the given DOCUMENT_ID.
    procedure UpdateChunksDomain(ADocumentId, ADomainId: Integer);
    // Sets the vector_size macro and executes ScriptEmbedding
    // (presumably creating the chunk table if missing - confirm in the .dfm).
    procedure CheckEmbeddingTable;

    // Commit / Rollback act only when MainConnection is in a transaction.
    procedure Commit;
    procedure Rollback;

    // Provider name; it is lower-cased before being passed to the queries.
    property Provider: string read FProvider write FProvider;
    // Model identifier passed unchanged to the queries.
    property ModelId: string read FModelId write FModelId;

    // Connects MainConnection using the RAG_CONNECTION environment variable.
    constructor Create(AOwner: TComponent); override;
    // Disconnects MainConnection before the module is destroyed.
    destructor Destroy; override;
  end;

implementation

const
  // Should be read from the embedder
  VectorSize = 768;

{$R *.dfm}

{ TMainDataModule }

// Reserves and returns a new document id taken from the 'seq_documenti'
// sequence. No row is inserted here: AFileName and ASummary are not used by
// the active code.
function TMainDataModule.AddDocument(const AFileName, ASummary: string): Integer;
const
  DocumentSequence = 'seq_documenti';
begin
  Result := GetNextVal(DocumentSequence);

  // Disabled implementation based on dsInsertDocument, kept for reference:
//  dsInsertDocument.ParamByName('filename').AsString := AFileName;
//  dsInsertDocument.ParamByName('tag').AsString := 'tag';
//  dsInsertDocument.ParamByName('summary').AsString := ASummary;
//  dsInsertDocument.ParamByName('provider').AsString := LowerCase(FProvider);
//  dsInsertDocument.ParamByName('model').AsString := FModelId;
//  dsInsertDocument.Open;
//  try
//    if dsInsertDocument.Eof then
//      raise Exception.Create('Non riesco ad inserire il documento');
//    Result := dsInsertDocument.Fields[0].AsInteger;
//  finally
//    dsInsertDocument.Close;
//  end;
end;

// Passes the configured vector size to ScriptEmbedding through its
// 'vector_size' macro and then runs the script.
procedure TMainDataModule.CheckEmbeddingTable;
var
  LSizeText: string;
begin
  LSizeText := IntToStr(VectorSize);
  ScriptEmbedding.MacroByName('vector_size').Value := LSizeText;
  ScriptEmbedding.Execute;
end;

// Commits the current transaction of MainConnection.
// Does nothing when no transaction is active.
procedure TMainDataModule.Commit;
begin
  if not MainConnection.InTransaction then
    Exit;

  MainConnection.Commit;
end;

// Creates the data module and immediately opens MainConnection, using the
// connection string found in the RAG_CONNECTION environment variable.
constructor TMainDataModule.Create(AOwner: TComponent);
var
  LConnectString: string;
begin
  inherited Create(AOwner);

  LConnectString := GetEnvironmentVariable('RAG_CONNECTION');
  MainConnection.ConnectString := LConnectString;
  MainConnection.Connect;
end;

// Removes every row of the chunk table (see GetTableName) whose DOCUMENT_ID
// equals ADocumentId.
procedure TMainDataModule.DeleteChunks(ADocumentId: Integer);
var
  LStatement: string;
begin
  LStatement := 'DELETE FROM ' + GetTableName;
  LStatement := LStatement + ' WHERE DOCUMENT_ID = ' + IntToStr(ADocumentId);
  MainConnection.ExecSQL(LStatement);
end;

// Removes the row of table DOCUMENTI whose DOC_ID equals AId.
procedure TMainDataModule.DeleteDocument(AId: Integer);
var
  LStatement: string;
begin
  LStatement := 'DELETE FROM DOCUMENTI WHERE DOC_ID = ' + IntToStr(AId);
  MainConnection.ExecSQL(LStatement);
end;

// Closes MainConnection and then lets the inherited destructor release
// the module.
destructor TMainDataModule.Destroy;
begin
  MainConnection.Disconnect;

  inherited Destroy;
end;

// Converts AValues to a comma separated list of numbers.
// Each value is formatted with the invariant format settings, so the decimal
// separator does not depend on the machine locale. The result has no
// surrounding brackets; an empty array yields an empty string.
function TMainDataModule.FloatArrayToStr(
  const AValues: TArray<Extended>): string;
var
  LFormat: TFormatSettings;
  LParts: TArray<string>;
  LIndex: Integer;
begin
  LFormat := TFormatSettings.Invariant;
  SetLength(LParts, Length(AValues));

  for LIndex := Low(AValues) to High(AValues) do
    LParts[LIndex] := FloatToStr(AValues[LIndex], LFormat);

  Result := string.Join(',', LParts);
end;

// Returns the next value of the database sequence named AGenerator by running
// "select nextval('<name>')" on MainConnection.
//
// AGenerator is passed through QuotedStr, so embedded quotes are escaped.
// Raises Exception when the query returns no row; the message includes the
// sequence name.
function TMainDataModule.GetNextVal(const AGenerator: string): Integer;
var
  UniQuery: TUniQuery;
begin
  // A temporary query is used so no design-time component state is touched.
  UniQuery := TUniQuery.Create(nil);
  try
    UniQuery.Connection := MainConnection;
    UniQuery.SQL.Text := 'select nextval(' + QuotedStr(AGenerator) + ')';
    UniQuery.Execute;
    if UniQuery.IsEmpty then
      // The format string needs a %s placeholder, otherwise the sequence name
      // supplied in the argument list never reaches the message.
      raise Exception.CreateFmt('Errore lettura sequence %s', [AGenerator]);
    Result := UniQuery.Fields[0].AsInteger;
  finally
    UniQuery.Free;
  end;
end;

//function TMainDataModule.GetQuery(const ASQL: string): IQuery;
//var
//  UniQuery: TUniQuery;
//begin
//  UniQuery := TUniQuery.Create(nil);
//  UniQuery.Connection := MainConnection;
//  UniQuery.SQL.Text := ASQL;
//  Result := TEDMUniQuery.Create(UniQuery, True);
//end;
//
//function TMainDataModule.GetQuery(ADataSet: TDataSet): IQuery;
//begin
//  if ADataSet is TUniQuery then
//    Result := TEDMUniQuery.Create(TUniQuery(ADataSet))
//  else
//    raise Exception.Create('DataSet non supportato');
//end;

// Runs QryGetEmbeddings for the given query embedding and returns one
// TDocumentChunk per row of the result.
//
// AQueryEmbedding  is sent as a '[v1,v2,...]' text literal in param 'vector'.
// ADomain          is sent in param 'domain_id'.
// ANumberOfChunks  is substituted into the 'limit' macro.
//
// Only Tag and Content of each returned chunk are filled in; Embedding is left
// empty. The query is always closed before returning.
function TMainDataModule.GetRelevantDocsFromDB(const AQueryEmbedding: TArray<Extended>;
  ADomain:Integer; ANumberOfChunks: Integer): TArray<TDocumentChunk>;
var
  LQuery: TUniQuery;
  LRow: Integer;
begin
  Result := [];
  LQuery := QryGetEmbeddings;

  LQuery.ParamByName('provider').AsString := LowerCase(FProvider);
  LQuery.ParamByName('model').AsString := FModelId;
  LQuery.ParamByName('vector').AsString := '[' + FloatArrayToStr(AQueryEmbedding) + ']';
  LQuery.ParamByName('domain_id').AsInteger := ADomain;
  LQuery.MacroByName('table_name').Value := GetTableName;
  LQuery.MacroByName('limit').Value := IntToStr(ANumberOfChunks);

  LQuery.Open;
  try
    // Force the whole dataset to be fetched so RecordCount is final.
    LQuery.Last;
    SetLength(Result, LQuery.RecordCount);

    LRow := 0;
    LQuery.First;
    while not LQuery.Eof do
    begin
      Result[LRow].Tag := LQuery.FieldByName('tag').AsString;
      Result[LRow].Content := LQuery.FieldByName('document').AsString;
      LQuery.Next;
      Inc(LRow);
    end;
  finally
    LQuery.Close;
  end;
end;

// Builds the name of the chunk table: the fixed prefix 'rag_data_' followed
// by the configured vector size.
function TMainDataModule.GetTableName: string;
var
  LSuffix: string;
begin
  LSuffix := IntToStr(VectorSize);
  Result := 'rag_data_' + LSuffix;
end;

// Rolls back the current transaction of MainConnection.
// Does nothing when no transaction is active.
procedure TMainDataModule.Rollback;
begin
  if not MainConnection.InTransaction then
    Exit;

  MainConnection.Rollback;
end;

// Stores every chunk of AChunks by executing dsInsertChunk once per chunk.
//
// AFileName    is not used by this routine.
// AChunks      chunks to store; each embedding is sent as a '[v1,v2,...]'
//              text literal.
// ADocumentId  value written to the 'document_id' param of every row.
//
// The current Provider (lower-cased) and ModelId are stored with each chunk.
procedure TMainDataModule.SaveChunks(const AFileName: string;
  const AChunks: TArray<TDocumentChunk>; ADocumentId: Integer);
var
  LInsert: TUniQuery;
  LIndex: Integer;
begin
  // Params: :tag, :document, :embedding, :provider, :model, :document_id
  LInsert := dsInsertChunk;
  LInsert.MacroByName('table_name').Value := GetTableName;

  for LIndex := Low(AChunks) to High(AChunks) do
  begin
    LInsert.ParamByName('tag').AsString := AChunks[LIndex].Tag;
    LInsert.ParamByName('document').AsString := AChunks[LIndex].Content;
    LInsert.ParamByName('embedding').AsString :=
      '[' + FloatArrayToStr(AChunks[LIndex].Embedding) + ']';
    LInsert.ParamByName('provider').AsString := LowerCase(Provider);
    LInsert.ParamByName('model').AsString := ModelId;
    LInsert.ParamByName('document_id').AsInteger := ADocumentId;
    LInsert.Execute;
  end;
end;

// Assigns ADomainId to the DOMAIN_ID column of every row of the chunk table
// (see GetTableName) whose DOCUMENT_ID equals ADocumentId.
procedure TMainDataModule.UpdateChunksDomain(ADocumentId, ADomainId: Integer);
var
  LStatement: string;
begin
  LStatement := 'UPDATE ' + GetTableName + ' SET DOMAIN_ID = ' + IntToStr(ADomainId);
  LStatement := LStatement + ' WHERE DOCUMENT_ID = ' + IntToStr(ADocumentId);
  MainConnection.ExecSQL(LStatement);
end;

end.
