■ 다중 클래스 분류 모델 사용하기

------------------------------------------------------------------------------------------------------------------------


TestProject.zip


IssueData.cs

 

 

using Microsoft.ML.Data;

 

namespace TestProject
{
    /// <summary>
    /// A single row of the issue data file.
    /// </summary>
    /// <remarks>
    /// Each <c>[LoadColumn(n)]</c> binds a property to the zero-based column <c>n</c> of the text file
    /// read by <c>LoadFromTextFile&lt;IssueData&gt;</c>. In this sample, <c>Area</c> is mapped to the
    /// training label, and <c>Title</c>/<c>Description</c> are featurized as the model inputs.
    /// </remarks>
    public class IssueData
    {
        /// <summary>Issue identifier (column 0).</summary>
        [LoadColumn(0)]
        public string ID { get; set; }

        /// <summary>Area the issue belongs to (column 1); the value the classifier learns to predict.</summary>
        [LoadColumn(1)]
        public string Area { get; set; }

        /// <summary>Issue title text (column 2).</summary>
        [LoadColumn(2)]
        public string Title { get; set; }

        /// <summary>Issue description text (column 3).</summary>
        [LoadColumn(3)]
        public string Description { get; set; }
    }
}

 

 

IssuePrediction.cs

 

 

using Microsoft.ML.Data;

 

namespace TestProject
{
    /// <summary>
    /// Model output for a single issue.
    /// </summary>
    public class IssuePrediction
    {
        /// <summary>
        /// Predicted area, bound to the pipeline's <c>"PredictedLabel"</c> output column.
        /// </summary>
        /// <remarks>
        /// This is exposed as an auto-property rather than a public field, matching <c>IssueData</c>.
        /// ML.NET binds output columns to public properties as well as public fields, so
        /// <c>[ColumnName]</c> works the same way here.
        /// </remarks>
        [ColumnName("PredictedLabel")]
        public string Area { get; set; }
    }
}

 

 

Program.cs

 

 

using System;

using System.IO;

 

using Microsoft.ML;

using Microsoft.ML.Data;

 

namespace TestProject
{
    /// <summary>
    /// Sample program: trains a multiclass classifier that predicts an issue's Area from its
    /// Title and Description, evaluates it on a test file, saves it to disk, and reloads it
    /// to make one more prediction.
    /// </summary>
    class Program
    {
        //////////////////////////////////////////////////////////////////////////////////////////////////// Field
        ////////////////////////////////////////////////////////////////////////////////////////// Static
        //////////////////////////////////////////////////////////////////////////////// Private

        #region Field

        /// <summary>
        /// Base directory of the running application.
        /// </summary>
        /// <remarks>
        /// <c>AppContext.BaseDirectory</c> is used rather than
        /// <c>Path.GetDirectoryName(Environment.GetCommandLineArgs()[0])</c>. The latter depends on
        /// whatever the launcher passed as argv[0], and it returns null for a root path, which
        /// makes <c>Path.Combine</c> throw.
        /// </remarks>
        private static string _applicationPath => AppContext.BaseDirectory;

        /// <summary>
        /// Path of the training data file.
        /// </summary>
        private static string _trainingDataFilePath => Path.Combine(_applicationPath, "Data", "issues_train.tsv");

        /// <summary>
        /// Path of the test data file.
        /// </summary>
        private static string _testDataFilePath => Path.Combine(_applicationPath, "Data", "issues_test.tsv");

        /// <summary>
        /// Path the trained model is saved to and reloaded from.
        /// </summary>
        private static string _modelFilePath => Path.Combine(_applicationPath, "Data", "model.zip");

        /// <summary>
        /// ML.NET context shared by all steps (created with a fixed seed in <see cref="Main"/>).
        /// </summary>
        private static MLContext _context;

        /// <summary>
        /// Trained model.
        /// </summary>
        private static ITransformer _model;

        /// <summary>
        /// Training data view.
        /// </summary>
        private static IDataView _trainingDataView;

        #endregion

        //////////////////////////////////////////////////////////////////////////////////////////////////// Method
        ////////////////////////////////////////////////////////////////////////////////////////// Static
        //////////////////////////////////////////////////////////////////////////////// Private

        #region Program entry - Main()

        /// <summary>
        /// Entry point: load data, build and fit the pipeline, predict one sample,
        /// evaluate on the test set, save the model, then reload it and predict again.
        /// </summary>
        private static void Main()
        {
            Console.WriteLine("BEGIN MAIN FUNCTION");

            _context = new MLContext(seed: 0);

            Console.WriteLine("BEGIN SET TRAINING DATA VIEW");
            _trainingDataView = _context.Data.LoadFromTextFile<IssueData>(_trainingDataFilePath, hasHeader: true);
            Console.WriteLine("END SET TRAINING DATA VIEW");

            Console.WriteLine("BEGIN SET PIPE LINE");
            var pipeline = GetPipeLine();
            Console.WriteLine("END SET PIPE LINE");

            _model = GetModel(_trainingDataView, pipeline);

            Console.WriteLine("BEGIN PREDICT SINGLE ITEM");

            IssueData issueData = new IssueData()
            {
                Title       = "WebSockets communication is slow in my machine",
                Description = "The WebSockets communication used under the covers by SignalR looks like is going slow " +
                              "in my development machine.."
            };

            WritePrediction(Predict(_model, issueData));

            Evaluate();

            SaveModel(_context, _trainingDataView.Schema, _model);

            PredictIssue();

            Console.WriteLine("END MAIN FUNCTION");
        }

        #endregion

        #region Build pipeline - GetPipeLine()

        /// <summary>
        /// Builds the training pipeline.
        /// </summary>
        /// <remarks>
        /// Steps:
        /// - map the string Area to a key-typed Label;
        /// - featurize Title and Description and concatenate them into Features;
        /// - cache the featurized data before training;
        /// - train SDCA maximum-entropy;
        /// - map PredictedLabel back from a key to the original string.
        /// </remarks>
        /// <returns>The untrained pipeline.</returns>
        private static IEstimator<ITransformer> GetPipeLine()
        {
            Console.WriteLine("BEGIN GET PIPE LINE FUNCTION");

            var pipeline = _context.Transforms.Conversion.MapValueToKey(inputColumnName : "Area", outputColumnName : "Label")
                .Append(_context.Transforms.Text.FeaturizeText(inputColumnName : "Title"      , outputColumnName : "TitleFeaturized"      ))
                .Append(_context.Transforms.Text.FeaturizeText(inputColumnName : "Description", outputColumnName : "DescriptionFeaturized"))
                .Append(_context.Transforms.Concatenate("Features", "TitleFeaturized", "DescriptionFeaturized"))
                .AppendCacheCheckpoint(_context)
                .Append(_context.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features"))
                .Append(_context.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            Console.WriteLine("END GET PIPE LINE FUNCTION");

            return pipeline;
        }

        #endregion

        #region Train model - GetModel(trainingDataView, pipeline)

        /// <summary>
        /// Fits the pipeline on the training data.
        /// </summary>
        /// <param name="trainingDataView">Training data view.</param>
        /// <param name="pipeline">Pipeline to fit.</param>
        /// <returns>The trained model.</returns>
        private static ITransformer GetModel(IDataView trainingDataView, IEstimator<ITransformer> pipeline)
        {
            Console.WriteLine("BEGIN GET MODEL FUNCTION");

            var model = pipeline.Fit(trainingDataView);

            Console.WriteLine("END GET MODEL FUNCTION");

            return model;
        }

        #endregion

        #region Evaluate - Evaluate()

        /// <summary>
        /// Scores the test file with the trained model and prints multiclass metrics.
        /// </summary>
        private static void Evaluate()
        {
            Console.WriteLine("BEGIN EVALUATE FUNCTION");

            IDataView testDataView = _context.Data.LoadFromTextFile<IssueData>(_testDataFilePath, hasHeader : true);

            MulticlassClassificationMetrics metrics = _context.MulticlassClassification.Evaluate(_model.Transform(testDataView));

            // "0.###" (not "#.###") keeps the leading zero: with "#.###", 0.45 prints as ".45",
            // -0.5 as "-.5", and 0 as an empty string.
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine("METRICS FOR MULTI-CLASS CLASSIFICATION MODEL - TEST DATA");
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine($"MICRO ACCURACY     : {metrics.MicroAccuracy:0.###}");
            Console.WriteLine($"MACRO ACCURACY     : {metrics.MacroAccuracy:0.###}");
            Console.WriteLine($"LOG LOSS           : {metrics.LogLoss:0.###}");
            Console.WriteLine($"LOG LOSS REDUCTION : {metrics.LogLossReduction:0.###}");
            Console.WriteLine("--------------------------------------------------");

            Console.WriteLine("END EVALUATE FUNCTION");
        }

        #endregion

        #region Save model - SaveModel(context, trainingDataViewSchema, model)

        /// <summary>
        /// Saves the model together with its input schema to <see cref="_modelFilePath"/>.
        /// </summary>
        /// <param name="context">ML context.</param>
        /// <param name="trainingDataViewSchema">Schema of the training data (the model's input schema).</param>
        /// <param name="model">Model to save.</param>
        private static void SaveModel(MLContext context, DataViewSchema trainingDataViewSchema, ITransformer model)
        {
            Console.WriteLine("BEGIN SAVE MODEL FUNCTION");

            context.Model.Save(model, trainingDataViewSchema, _modelFilePath);

            Console.WriteLine("END SAVE MODEL FUNCTION");
        }

        #endregion

        #region Predict with saved model - PredictIssue()

        /// <summary>
        /// Reloads the saved model from disk and predicts the area of a fixed sample issue.
        /// </summary>
        private static void PredictIssue()
        {
            Console.WriteLine("BEGIN PREDICT ISSUE FUNCTION");

            // The loaded input schema is not needed here, so it is discarded.
            ITransformer model = _context.Model.Load(_modelFilePath, out _);

            IssueData issueData = new IssueData()
            {
                Title       = "Entity Framework crashes",
                Description = "When connecting to the database, EF is crashing"
            };

            WritePrediction(Predict(model, issueData));

            Console.WriteLine("END PREDICT ISSUE FUNCTION");
        }

        #endregion

        #region Helpers - Predict(model, issueData), WritePrediction(issuePrediction)

        /// <summary>
        /// Runs a single prediction through a short-lived prediction engine.
        /// </summary>
        /// <param name="model">Trained model to predict with.</param>
        /// <param name="issueData">Issue to classify.</param>
        /// <returns>The prediction for <paramref name="issueData"/>.</returns>
        private static IssuePrediction Predict(ITransformer model, IssueData issueData)
        {
            // PredictionEngine is IDisposable; dispose it instead of leaking it in a static field.
            using (PredictionEngine<IssueData, IssuePrediction> predictionEngine =
                       _context.Model.CreatePredictionEngine<IssueData, IssuePrediction>(model))
            {
                return predictionEngine.Predict(issueData);
            }
        }

        /// <summary>
        /// Prints a predicted area between separator lines.
        /// </summary>
        /// <param name="issuePrediction">Prediction to print.</param>
        private static void WritePrediction(IssuePrediction issuePrediction)
        {
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine($"AREA : {issuePrediction.Area}");
            Console.WriteLine("--------------------------------------------------");
        }

        #endregion
    }
}

 

------------------------------------------------------------------------------------------------------------------------

Posted by 사용자 icodebroker